【HNOI】 c tree-dp

  【题目描述】给定一个n个节点的树,每个节点有两个属性值a[i],b[i],我们可以在树中选取一个连通块G,这个连通块的值为(Σa[x])(Σb[x]) x∈G,求所有连通块的值的和,输出答案对1000000007取余。

  【数据范围】n<=10^5.

  首先我们任选一个点作为根,变成一颗有根树。观察答案为(Σa[x])(Σb[x]),那么我们可以将这个答案展开成为每一个b[x]乘上所有可能情况下的a[y],这个可能情况就是x点在连通块中时,b[x]乘上连通块内所有点的a值去和,再枚举所有的连通块,就可以求出来b[x]对答案的贡献,那么我们现在问题就转化为了求出来一个节点,所有包括这个节点的连通块的a值和,每个连通块的a值为连通块内所有点的a值和,设这个值为sum_[x]。

  我们的sum_[x]的值的求方法可以为x子树中每个a被累加的次数加上非x子树节点a值被累加的次数,那么我们可以依次求出来这两个,然后求出sum_[x]。

  我们设w[x]为以x为根的子树中,包含x节点的连通块的数量,sum[x]为以x为根的子树中,包含x的所有连通块的a值和,w_[x]为所有包含x节点的连通块的数量。

  有了这些量,我们就可以求出sum_[x],先考虑这些量的转移。

  w[x]=π(w[son of x]+1).

  sum[x]=Σ(w[x]/(w[son of x]+1)*sum[son of x]).

  这两个量的转移是由子节点到根的,比较容易考虑,现在我们有了这两个量之后,考虑用这两个量转移其余的两个量。

  w_[x]=(w_[father of x]/(w[x]+1)+1)*w[x].

  那么sum_[x]就等于之前说的两部分相加,则

  sum_[x]=w_[father of x]/(w[x]+1)+1)*sum[x]+(sum_[father of x]-w_[father of x]/(w[x]+1)*sum[x])/(w[x]+1)*w[x].

  反思:为了提高速度没开LL,用到的地方强转的LL,然后有的地方忘加了,纠结了好久= =。

//By BLADEVIL

#include <cstdio>

#define d39 1000000007

#define maxn 100010

#define LL long long



using namespace std;



int n,l;

int last[maxn],other[maxn<<1],pre[maxn<<1],a[maxn],b[maxn],que[maxn],dis[maxn];

int sum[maxn],w[maxn],sum_[maxn],w_[maxn];



void connect(int x,int y) {

    pre[++l]=last[x];

    last[x]=l;

    other[l]=y;

}



int pw(int x,int k) {

    int ans=1;

    while (k) {

        if (k&1) ans=((LL)ans*x)%d39; 

        x=((LL)x*x)%d39;

        k>>=1;

    }

    return ans;

}



int main() {

    freopen("c.in","r",stdin); freopen("c.out","w",stdout);

    scanf("%d",&n);

    for (int i=1;i<n;i++) {

        int x,y; scanf("%d%d",&x,&y);

        connect(x,y); connect(y,x);

    }

    for (int i=1;i<=n;i++) scanf("%d",&a[i]);

    for (int i=1;i<=n;i++) scanf("%d",&b[i]);

    int h=0,t=1; que[1]=1; dis[1]=1;

    while (h<t) {

        int cur=que[++h];

        for (int p=last[cur];p;p=pre[p]) {

            if (dis[other[p]]) continue;

            que[++t]=other[p]; dis[other[p]]=dis[cur]+1;

        }

    }

    //for (int i=1;i<=n;i++) printf("%d ",que[i]); printf("\n");

    for (int i=n;i;i--) {

        int cur=que[i];

        w[cur]=1;

        for (int p=last[cur];p;p=pre[p]) {

            if (dis[other[p]]<dis[cur]) continue;

            w[cur]=((LL)w[cur]*(w[other[p]]+1))%d39;

        }

        sum[cur]=((LL)w[cur]*a[cur])%d39;

        for (int p=last[cur];p;p=pre[p]) {

            if (dis[other[p]]<dis[cur]) continue;

            sum[cur]=(sum[cur]+((LL)((LL)w[cur]*pw(w[other[p]]+1,d39-2)%d39)*sum[other[p]])%d39)%d39;

        }

    }

    //for (int i=1;i<=n;i++) printf("%d %d %d\n",i,sum[i],w[i]);

    for (int i=1;i<=n;i++) {

        int cur=que[i];

        if (cur==1) {

            w_[cur]=w[cur];

            sum_[cur]=sum[cur];

        }

        for (int p=last[cur];p;p=pre[p]) {

            if (dis[other[p]]<dis[cur]) continue;        

            //printf("%d\n",pw(w_[cur]*w[other[p]]+1,d39-2)%d39);

            int tot=(LL)w_[cur]*pw(w[other[p]]+1,d39-2)%d39;

            //printf("%d\n",tot);

            w_[other[p]]=((LL)(tot+1)%d39*w[other[p]]%d39);

            sum_[other[p]]=((LL)(tot+1)*sum[other[p]]%d39+(LL)((LL)(sum_[cur]-(LL)tot*sum[other[p]]%d39+d39))%d39*w[other[p]]%d39*pw(w[other[p]]+1,d39-2)%d39)%d39;

        }

    }

    //for (int i=1;i<=n;i++) printf("%d %d %d %d\n",i,w[i],sum[i],sum_[i]);

    int ans=0;

    for (int i=1;i<=n;i++) ans=(ans+(LL)sum_[i]*b[i])%d39;

    printf("%d\n",ans);

    fclose(stdin); fclose(stdout);

    return 0;

}

 

你可能感兴趣的:(tree)