A君住在魔法森林里,魔法森林可以看做一棵n个结点的树,结点从1~n编号。树中的每个结点上都生长着蘑菇。蘑菇有许多不同的种类,但同一个结点上的蘑菇都是同一种类,更具体地,i号结点上生长着种类为c[i]的蘑菇。
现在A君打算出去采蘑菇,但他并不知道哪里的蘑菇更好,因此他选定起点s后会等概率随机选择树中的某个结点t作为终点,之后从s沿着(s,t)间的最短路径走到t.并且A君会采摘途中所经过的所有结点上的蘑菇。
现在A君想知道,对于每一个结点u,假如他从这个结点出发,他最后能采摘到的蘑菇种类数的期望是多少。为了方便,你告诉A君答案*n的值即可。
考虑点分治。
还是可以维护,自行讨论,不难
稍复杂一点的点剖
#include
#include
#include
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
using namespace std;
typedef long long ll;
const int maxn=300000+10;
int col[maxn],c[maxn],a[maxn],size[maxn],cnt[maxn],d[maxn],belong[maxn];
bool bz[maxn],pd[maxn];
int h[maxn],go[maxn*2],next[maxn*2],sta[80];
ll ans[maxn],num;
int i,j,k,l,t,n,m,tot,top;
int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
void add(int x,int y){
go[++tot]=y;
next[tot]=h[x];
h[x]=tot;
}
void write(ll x){
if (!x){
putchar('0');
putchar('\n');
return;
}
top=0;
while (x){
sta[++top]=x%10;
x/=10;
}
while (top) putchar('0'+sta[top--]);
putchar('\n');
}
void travel(int x,int y){
a[++top]=x;
int t=h[x];
size[x]=1;
while (t){
if (!bz[go[t]]&&go[t]!=y) {
travel(go[t],x);
size[x]+=size[go[t]];
}
t=next[t];
}
}
void dg(int x,int y,int z){
a[++top]=x;
belong[x]=z;
cnt[c[x]]++;
d[x]=d[y];
if (cnt[c[x]]==1){
pd[x]=1;
d[x]++;
}
else pd[x]=0;
int t=h[x];
size[x]=1;
while (t){
if (!bz[go[t]]&&go[t]!=y){
dg(go[t],x,z);
size[x]+=size[go[t]];
}
t=next[t];
}
cnt[c[x]]--;
}
void calc(int x,int y,int z){
if (pd[x]) z+=col[c[x]];
ans[x]-=(ll)z;
int t=h[x];
while (t){
if (!bz[go[t]]&&go[t]!=y) calc(go[t],x,z);
t=next[t];
}
if (pd[x]) z-=col[c[x]];
}
void solve(int x){
top=0;
travel(x,0);
int i,j=x,k=0,r,t;
while (1){
t=h[j];
while (t){
if (!bz[go[t]]&&go[t]!=k&&size[go[t]]>top/2){
k=j;
j=go[t];
break;
}
t=next[t];
}
if (!t) break;
}
top=0;
a[++top]=j;
cnt[c[j]]=1;
d[j]=1;
t=h[j];
while (t){
if (!bz[go[t]]) dg(go[t],j,go[t]);
t=next[t];
}
cnt[c[j]]=0;
size[j]=1;
i=2;
num=d[j];
while (i<=top){
k=i;
while (k1]]==belong[a[i]]) k++;
fo(r,i,k){
ans[a[r]]+=num;
ans[a[r]]+=(ll)size[j]*(d[a[r]]-1);
}
calc(belong[a[i]],j,0);
fo(r,i,k)
if (pd[a[r]]) col[c[a[r]]]+=size[a[r]];
fo(r,i,k) num+=d[a[r]];
size[j]+=(k-i+1);
i=k+1;
}
fo(i,2,top) col[c[a[i]]]=0;
reverse(a+2,a+top+1);
size[j]=1;
i=2;
num=d[j];
while (i<=top){
k=i;
while (k1]]==belong[a[i]]) k++;
fo(r,i,k){
ans[a[r]]+=num;
ans[a[r]]+=(ll)size[j]*(d[a[r]]-1);
}
calc(belong[a[i]],j,0);
fo(r,i,k)
if (pd[a[r]]) col[c[a[r]]]+=size[a[r]];
fo(r,i,k) num+=d[a[r]];
size[j]+=(k-i+1);
i=k+1;
}
fo(i,1,top) ans[j]+=(ll)d[a[i]];
fo(i,2,top) ans[a[i]]-=(ll)d[a[i]];
fo(i,2,top) col[c[a[i]]]=0;
bz[j]=1;
t=h[j];
while (t){
if (!bz[go[t]]) solve(go[t]);
t=next[t];
}
}
int main(){
freopen("mushroom.in","r",stdin);freopen("mushroom.out","w",stdout);
srand(19890604);
n=read();
fo(i,1,n) c[i]=read();
fo(i,1,n-1){
j=read();k=read();
add(j,k);add(k,j);
}
//solve(1);
solve(n/2);
fo(i,1,n) write(ans[i]);
}