【题目描述】给定一个n个节点的树,每个节点有两个属性值a[i],b[i],我们可以在树中选取一个连通块G,这个连通块的值为(Σa[x])(Σb[x]) x∈G,求所有连通块的值的和,输出答案对1000000007取余。
【数据范围】n<=10^5.
首先我们任选一个点作为根,变成一颗有根树。观察答案为(Σa[x])(Σb[x]),那么我们可以将这个答案展开成为每一个b[x]乘上所有可能情况下的a[y],这个可能情况就是x点在连通块中时,b[x]乘上连通块内所有点的a值去和,再枚举所有的连通块,就可以求出来b[x]对答案的贡献,那么我们现在问题就转化为了求出来一个节点,所有包括这个节点的连通块的a值和,每个连通块的a值为连通块内所有点的a值和,设这个值为sum_[x]。
我们的sum_[x]的值的求方法可以为x子树中每个a被累加的次数加上非x子树节点a值被累加的次数,那么我们可以依次求出来这两个,然后求出sum_[x]。
我们设w[x]为以x为根的子树中,包含x节点的连通块的数量,sum[x]为以x为根的子树中,包含x的所有连通块的a值和,w_[x]为所有包含x节点的连通块的数量。
有了这些量,我们就可以求出sum_[x],先考虑这些量的转移。
w[x]=π(w[son of x]+1).
sum[x]=Σ(w[x]/(w[son of x]+1)*sum[son of x]).
这两个量的转移是由子节点到根的,比较容易考虑,现在我们有了这两个量之后,考虑用这两个量转移其余的两个量。
w_[x]=(w_[father of x]/(w[x]+1)+1)*w[x].
那么sum_[x]就等于之前说的两部分相加,则
sum_[x]=w_[father of x]/(w[x]+1)+1)*sum[x]+(sum_[father of x]-w_[father of x]/(w[x]+1)*sum[x])/(w[x]+1)*w[x].
反思:为了提高速度没开LL,用到的地方强转的LL,然后有的地方忘加了,纠结了好久= =。
//By BLADEVIL #include <cstdio> #define d39 1000000007 #define maxn 100010 #define LL long long using namespace std; int n,l; int last[maxn],other[maxn<<1],pre[maxn<<1],a[maxn],b[maxn],que[maxn],dis[maxn]; int sum[maxn],w[maxn],sum_[maxn],w_[maxn]; void connect(int x,int y) { pre[++l]=last[x]; last[x]=l; other[l]=y; } int pw(int x,int k) { int ans=1; while (k) { if (k&1) ans=((LL)ans*x)%d39; x=((LL)x*x)%d39; k>>=1; } return ans; } int main() { freopen("c.in","r",stdin); freopen("c.out","w",stdout); scanf("%d",&n); for (int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); connect(x,y); connect(y,x); } for (int i=1;i<=n;i++) scanf("%d",&a[i]); for (int i=1;i<=n;i++) scanf("%d",&b[i]); int h=0,t=1; que[1]=1; dis[1]=1; while (h<t) { int cur=que[++h]; for (int p=last[cur];p;p=pre[p]) { if (dis[other[p]]) continue; que[++t]=other[p]; dis[other[p]]=dis[cur]+1; } } //for (int i=1;i<=n;i++) printf("%d ",que[i]); printf("\n"); for (int i=n;i;i--) { int cur=que[i]; w[cur]=1; for (int p=last[cur];p;p=pre[p]) { if (dis[other[p]]<dis[cur]) continue; w[cur]=((LL)w[cur]*(w[other[p]]+1))%d39; } sum[cur]=((LL)w[cur]*a[cur])%d39; for (int p=last[cur];p;p=pre[p]) { if (dis[other[p]]<dis[cur]) continue; sum[cur]=(sum[cur]+((LL)((LL)w[cur]*pw(w[other[p]]+1,d39-2)%d39)*sum[other[p]])%d39)%d39; } } //for (int i=1;i<=n;i++) printf("%d %d %d\n",i,sum[i],w[i]); for (int i=1;i<=n;i++) { int cur=que[i]; if (cur==1) { w_[cur]=w[cur]; sum_[cur]=sum[cur]; } for (int p=last[cur];p;p=pre[p]) { if (dis[other[p]]<dis[cur]) continue; //printf("%d\n",pw(w_[cur]*w[other[p]]+1,d39-2)%d39); int tot=(LL)w_[cur]*pw(w[other[p]]+1,d39-2)%d39; //printf("%d\n",tot); w_[other[p]]=((LL)(tot+1)%d39*w[other[p]]%d39); sum_[other[p]]=((LL)(tot+1)*sum[other[p]]%d39+(LL)((LL)(sum_[cur]-(LL)tot*sum[other[p]]%d39+d39))%d39*w[other[p]]%d39*pw(w[other[p]]+1,d39-2)%d39)%d39; } } //for (int i=1;i<=n;i++) printf("%d %d %d %d\n",i,w[i],sum[i],sum_[i]); int ans=0; for (int i=1;i<=n;i++) ans=(ans+(LL)sum_[i]*b[i])%d39; printf("%d\n",ans); fclose(stdin); fclose(stdout); return 0; }