CF 434E 圣诞树(tree)

先上题目:

Description

圣诞节到了,小可可送给小薰一棵圣诞树。这棵圣诞树很奇怪,它是一棵多叉树,有n个点,n-1条边。它的每个结点都有一个权值。小可可和小薰想用这棵树玩一个游戏。
定义(s,e)为树上从s到e的简单路径,我们可以记下在这条路径上经过的结点,定义这个结点序列为S(s,e)。
我们按照如下方法定义这个序列S(s,e)的权值G(S(s,e)):假设这个序列中结点的权值为Z0,Z1,…,Z(L-1),其中L为序列的长度,我们定义G(S(s,e))=Z0 × k^0 + Z1 × k^1 + … + Z(L-1) × k^(L-1)。
如果路径(s,e)满足G(S(s,e)) ≡ x (mod y) ,那么这条路径属于小可可,否则这条路径属于小薰。小可可和小薰很显然不希望这个游戏变得那么简单。小薰认为如果路径(p1,p2)和(p2,p3)都属于他,那么路径(p1,p3)也属于他,反之如果路径(p1,p2)和(p2,p3)都属于小可可,那么路径(p1,p3)也属于小可可。然而这个性质并不总是正确的。所以小薰想知道到底有多少三元组(p1,p2,p3)满足这个性质。
小薰表示她看一眼就知道这道题怎么做了。你会吗?

Input

第一行包含四个整数n,y,k和x,其中n为圣诞树的结点数,y,k和x的含义如题目所示,题目保证y是一个质数。
第二行包含n个整数,第i个整数vi表示第i个结点的权值。
接下来n-1行,每行包含2个整数,表示树上的一条边。树的结点从1到n编号。

Output

包含一个整数,表示有多少整数组(p1,p2,p3)满足题目描述的性质。

Sample Input

输入1:

1 2 1 0
1

输入2:

3 5 2 1
4 3 1
1 2
2 3

输入3:

8 13 8 12
0 12 7 4 12 0 8 12
1 8
8 4
4 6
6 2
2 3
8 5
2 7

Sample Output

输出1:

1

输出2:

14

输出3:

341

Data Constraint

对于20%的数据,n ≤ 200;
对于50%的数据,n ≤ 10^4;
对于100%的数据,1 ≤ n ≤ 10^5,2 ≤ y ≤ 10^9,1 ≤ k ≤ y,0 ≤ x < y。

题目大意:

给你一颗大小为100000的树;
设S(x,y)路径上的点是d1,d2,…,dp
定义G (S(x,y))=sum(di*k^i)( mod y(质数) )
f(x,y)= G(S(x,y)) = x(mod y) 等于是1,不等于是0
问有多少个三元组(a,b,c),满足f(a,b) = f(b,c) = f(a,c)

解析:
我们想计算满足(i,j),(j,k),(i,k)权值都为 0 或都为 1 的三元组(i,j,k)个数。它等于三条边权值相同的有向三角形个数。

这样计算可能有一些困难。我们考虑三条边权值不全相同的有向三角形个数。

我们定义 in0[i]表示进入 i 的边中权值为 0 的个数。类似地定义 in1[i],out0[i],out1[i]。

令p=sum(in0[i]*in1[i]*2+in0[i]*out1[i]+in1[i]*out0[i]+out0[i]*out1[i]*2)

则Ans=n^3-p

我们当然可以暴力,只有50分。
我们还可以点分治。

让我们来考虑如何优化。我们可以使用树的分治来在 O(nlog2n)的时间复杂度下解决
此题。选择根 i,计算它的子树。我们可以得到子树中所有结点到 i 的权值。我们可以
先保存这些权值和路径长度。

从结点j出发的一条路径有权值v与长度L。我们想找到k使得G(S(j,k)) ≡ X(mod Y)。
令 H(i,j)为序列(i,j)除掉 i 的权值。

则 G(S(j,k))= G(S(j,i))+G(H(i,k))·KL = v+ G(H(i,k))·KL
所以有 G(H(i,k))=(X-v)/KL,因为 Y 是一个质数,所以我们就可以很容易地计算出z=(X-v)/KL。
现在这个问题变成了我们需要计算有多少条从 i 出发的不包含 i 的路径权值为 z。所
以我们可以用二分搜索的办法在排好序的数组中查询。我们可以用这种方法算出 in0 与 out0。

Code:

#include
#include
#define ll long long
#define fo(i,x,y) for(ll i=x;i<=y;i++)
#define max(a,b) ((a)>(b)?(a):(b))
using namespace std;

const ll maxn=100005;

ll n,mo,k,xx,ans,root,v[maxn],kf[maxn],kfn[maxn],bz[maxn];

ll msiz[maxn],siz[maxn],p1[maxn],p2[maxn],l[maxn], d[maxn], c[maxn];

ll out0[maxn],in0[maxn];

ll tot,final[maxn],next[maxn*2],to[maxn*2];

void link(ll x,ll y) {
    next[++tot]=final[x], to[tot]=y , final[x]=tot;
    next[++tot]=final[y], to[tot]=x , final[y]=tot;
}

ll ksm(ll x,ll y) {
    ll s=1;
    for(;y;) {
        if(y&1)  s=(s*x)%mo;
        x=(x*x)%mo; y>>=1;
    }
    return s;
}

void froot(ll x) {
    bz[x]=1;
        siz[x]=1;  msiz[x]=0;
        for(ll k=final[x];k;k=next[k]) {
            ll y=to[k]; if(bz[y]) continue;
            froot(y);
            msiz[x]=max(msiz[x],siz[y]);
            siz[x]+=siz[y];
        }
        msiz[x]=max(msiz[x],siz[0]-siz[x]);
        root=msiz[root]x]?root:x;
    bz[x]=0;
}

void mkt(ll x) {
    bz[x]=1;
        siz[x]=1;
        for(ll k=final[x];k;k=next[k]) {
            ll y=to[k]; if(bz[y]) continue;
            l[y]=l[x]+1;
            p1[y]=(p1[x]+v[y]*kf[l[x]])%mo;
            p2[y]=(p2[x]*kf[1]+v[y])%mo;
            mkt(y);
            siz[x]+=siz[y];
        }
        d[++d[0]]=x;
    bz[x]=0;
}

ll find(ll c[],ll len,ll vc) {
    ll l1=0, r1=-1;
    for(ll l=1, r=len; l <= r;) {
        ll mid=(l+r)/2;
        if(c[mid] < vc) l=mid+1;
        if(c[mid] > vc) r=mid-1;
        if(c[mid] == vc) l1=mid, r=mid-1;
    }
    for(ll l=1, r=len; l <= r;) {
        ll mid=(l+r)/2;
        if(c[mid] < vc) l=mid+1;
        if(c[mid] > vc) r=mid-1;
        if(c[mid] == vc) r1=mid, l=mid+1;
    }
    return r1-l1+1;
}

void mkt2(int x) {
    bz[x]=1;
        for(ll k=final[x];k;k=next[k]) {
            ll y=to[k]; if(bz[y]) continue;
            mkt(y);
        }
        d[++d[0]]=x;
    bz[x]=0;
}

void solve(int ff) {
    fo(i,1,d[0]) c[i]=p1[d[i]];
    sort(c+1,c+d[0]+1);
    fo(i,1,d[0]) {
        ll vc=((xx-p2[d[i]]+mo)*(kfn[l[d[i]]+1]))%mo;
        out0[d[i]]+=ff*find(c,d[0],vc);
    }

    fo(i,1,d[0]) c[i]=((xx-p2[d[i]]+mo)*kfn[l[d[i]]+1])%mo;
    sort(c+1,c+d[0]+1);
    fo(i,1,d[0]) {
        ll vc=p1[d[i]];
        in0[d[i]]+=ff*find(c,d[0],vc);
    }
}

void dg(ll x) {
    root=0; siz[0]=siz[x]; froot(x);
    l[root]=0; p2[root]=v[root]; p1[root]=0; d[0]=0; mkt(root);
    solve(1);
    bz[root]=1;
    for(ll k=final[root];k;k=next[k]) {
        ll y=to[k]; if(bz[y]) continue;
        d[0]=0; mkt2(y);
        solve(-1);
    }
    for(ll k=final[root];k;k=next[k]) {
        ll y=to[k]; if(bz[y]) continue;
        dg(y);
    }
}

int main() {
    scanf("%lld %lld %lld %lld", &n, &mo, &k, &xx);
    kf[0]=kfn[0]=1; fo(i,1,n) kf[i]=(kf[i-1]*k)%mo, kfn[i]=ksm(kf[i],mo-2);
    fo(i,1,n) scanf("%lld", &v[i]),v[i]%=mo;
    fo(i,1,n-1) {
        ll x,y; scanf("%lld %lld", &x, &y);
        link(x,y);
    }
    siz[1]=msiz[0]=n; dg(1);
    fo(i,1,n) ans+=2*out0[i]*(n-out0[i])+out0[i]*(n-in0[i])+(n-out0[i])*in0[i]+2*(n-in0[i])*in0[i];
    ans/=2; ans=n*n*n-ans;
    printf("%lld",ans);
}

你可能感兴趣的:(树分治)