先上题目:
圣诞节到了,小可可送给小薰一棵圣诞树。这棵圣诞树很奇怪,它是一棵多叉树,有n个点,n-1条边。它的每个结点都有一个权值。小可可和小薰想用这棵树玩一个游戏。
定义(s,e)为树上从s到e的简单路径,我们可以记下在这条路径上经过的结点,定义这个结点序列为S(s,e)。
我们按照如下方法定义这个序列S(s,e)的权值G(S(s,e)):假设这个序列中结点的权值为Z0,Z1,…,Z(L-1),其中L为序列的长度,我们定义G(S(s,e))=Z0 × k^0 + Z1 × k^1 + … + Z(L-1) × k^(L-1)。
如果路径(s,e)满足G(S(s,e)) ≡ x (mod y) ,那么这条路径属于小可可,否则这条路径属于小薰。小可可和小薰很显然不希望这个游戏变得那么简单。小薰认为如果路径(p1,p2)和(p2,p3)都属于他,那么路径(p1,p3)也属于他,反之如果路径(p1,p2)和(p2,p3)都属于小可可,那么路径(p1,p3)也属于小可可。然而这个性质并不总是正确的。所以小薰想知道到底有多少三元组(p1,p2,p3)满足这个性质。
小薰表示她看一眼就知道这道题怎么做了。你会吗?
第一行包含四个整数n,y,k和x,其中n为圣诞树的结点数,y,k和x的含义如题目所示,题目保证y是一个质数。
第二行包含n个整数,第i个整数vi表示第i个结点的权值。
接下来n-1行,每行包含2个整数,表示树上的一条边。树的结点从1到n编号。
包含一个整数,表示有多少整数组(p1,p2,p3)满足题目描述的性质。
1 2 1 0
1
3 5 2 1
4 3 1
1 2
2 3
8 13 8 12
0 12 7 4 12 0 8 12
1 8
8 4
4 6
6 2
2 3
8 5
2 7
1
14
341
对于20%的数据,n ≤ 200;
对于50%的数据,n ≤ 10^4;
对于100%的数据,1 ≤ n ≤ 10^5,2 ≤ y ≤ 10^9,1 ≤ k ≤ y,0 ≤ x < y。
题目大意:
给你一颗大小为100000的树;
设S(x,y)路径上的点是d1,d2,…,dp
定义G (S(x,y))=sum(di*k^i)( mod y(质数) )
f(x,y)= G(S(x,y)) = x(mod y) 等于是1,不等于是0
问有多少个三元组(a,b,c),满足f(a,b) = f(b,c) = f(a,c)
解析:
我们想计算满足(i,j),(j,k),(i,k)权值都为 0 或都为 1 的三元组(i,j,k)个数。它等于三条边权值相同的有向三角形个数。
这样计算可能有一些困难。我们考虑三条边权值不全相同的有向三角形个数。
我们定义 in0[i]表示进入 i 的边中权值为 0 的个数。类似地定义 in1[i],out0[i],out1[i]。
令p=sum(in0[i]*in1[i]*2+in0[i]*out1[i]+in1[i]*out0[i]+out0[i]*out1[i]*2)
则Ans=n^3-p
我们当然可以暴力,只有50分。
我们还可以点分治。
让我们来考虑如何优化。我们可以使用树的分治来在 O(nlog2n)的时间复杂度下解决
此题。选择根 i,计算它的子树。我们可以得到子树中所有结点到 i 的权值。我们可以
先保存这些权值和路径长度。
从结点j出发的一条路径有权值v与长度L。我们想找到k使得G(S(j,k)) ≡ X(mod Y)。
令 H(i,j)为序列(i,j)除掉 i 的权值。
则 G(S(j,k))= G(S(j,i))+G(H(i,k))·KL = v+ G(H(i,k))·KL
所以有 G(H(i,k))=(X-v)/KL,因为 Y 是一个质数,所以我们就可以很容易地计算出z=(X-v)/KL。
现在这个问题变成了我们需要计算有多少条从 i 出发的不包含 i 的路径权值为 z。所
以我们可以用二分搜索的办法在排好序的数组中查询。我们可以用这种方法算出 in0 与 out0。
Code:
#include
#include
#define ll long long
#define fo(i,x,y) for(ll i=x;i<=y;i++)
#define max(a,b) ((a)>(b)?(a):(b))
using namespace std;
const ll maxn=100005;
ll n,mo,k,xx,ans,root,v[maxn],kf[maxn],kfn[maxn],bz[maxn];
ll msiz[maxn],siz[maxn],p1[maxn],p2[maxn],l[maxn], d[maxn], c[maxn];
ll out0[maxn],in0[maxn];
ll tot,final[maxn],next[maxn*2],to[maxn*2];
void link(ll x,ll y) {
next[++tot]=final[x], to[tot]=y , final[x]=tot;
next[++tot]=final[y], to[tot]=x , final[y]=tot;
}
ll ksm(ll x,ll y) {
ll s=1;
for(;y;) {
if(y&1) s=(s*x)%mo;
x=(x*x)%mo; y>>=1;
}
return s;
}
void froot(ll x) {
bz[x]=1;
siz[x]=1; msiz[x]=0;
for(ll k=final[x];k;k=next[k]) {
ll y=to[k]; if(bz[y]) continue;
froot(y);
msiz[x]=max(msiz[x],siz[y]);
siz[x]+=siz[y];
}
msiz[x]=max(msiz[x],siz[0]-siz[x]);
root=msiz[root]x]?root:x;
bz[x]=0;
}
void mkt(ll x) {
bz[x]=1;
siz[x]=1;
for(ll k=final[x];k;k=next[k]) {
ll y=to[k]; if(bz[y]) continue;
l[y]=l[x]+1;
p1[y]=(p1[x]+v[y]*kf[l[x]])%mo;
p2[y]=(p2[x]*kf[1]+v[y])%mo;
mkt(y);
siz[x]+=siz[y];
}
d[++d[0]]=x;
bz[x]=0;
}
ll find(ll c[],ll len,ll vc) {
ll l1=0, r1=-1;
for(ll l=1, r=len; l <= r;) {
ll mid=(l+r)/2;
if(c[mid] < vc) l=mid+1;
if(c[mid] > vc) r=mid-1;
if(c[mid] == vc) l1=mid, r=mid-1;
}
for(ll l=1, r=len; l <= r;) {
ll mid=(l+r)/2;
if(c[mid] < vc) l=mid+1;
if(c[mid] > vc) r=mid-1;
if(c[mid] == vc) r1=mid, l=mid+1;
}
return r1-l1+1;
}
void mkt2(int x) {
bz[x]=1;
for(ll k=final[x];k;k=next[k]) {
ll y=to[k]; if(bz[y]) continue;
mkt(y);
}
d[++d[0]]=x;
bz[x]=0;
}
void solve(int ff) {
fo(i,1,d[0]) c[i]=p1[d[i]];
sort(c+1,c+d[0]+1);
fo(i,1,d[0]) {
ll vc=((xx-p2[d[i]]+mo)*(kfn[l[d[i]]+1]))%mo;
out0[d[i]]+=ff*find(c,d[0],vc);
}
fo(i,1,d[0]) c[i]=((xx-p2[d[i]]+mo)*kfn[l[d[i]]+1])%mo;
sort(c+1,c+d[0]+1);
fo(i,1,d[0]) {
ll vc=p1[d[i]];
in0[d[i]]+=ff*find(c,d[0],vc);
}
}
void dg(ll x) {
root=0; siz[0]=siz[x]; froot(x);
l[root]=0; p2[root]=v[root]; p1[root]=0; d[0]=0; mkt(root);
solve(1);
bz[root]=1;
for(ll k=final[root];k;k=next[k]) {
ll y=to[k]; if(bz[y]) continue;
d[0]=0; mkt2(y);
solve(-1);
}
for(ll k=final[root];k;k=next[k]) {
ll y=to[k]; if(bz[y]) continue;
dg(y);
}
}
int main() {
scanf("%lld %lld %lld %lld", &n, &mo, &k, &xx);
kf[0]=kfn[0]=1; fo(i,1,n) kf[i]=(kf[i-1]*k)%mo, kfn[i]=ksm(kf[i],mo-2);
fo(i,1,n) scanf("%lld", &v[i]),v[i]%=mo;
fo(i,1,n-1) {
ll x,y; scanf("%lld %lld", &x, &y);
link(x,y);
}
siz[1]=msiz[0]=n; dg(1);
fo(i,1,n) ans+=2*out0[i]*(n-out0[i])+out0[i]*(n-in0[i])+(n-out0[i])*in0[i]+2*(n-in0[i])*in0[i];
ans/=2; ans=n*n*n-ans;
printf("%lld",ans);
}