圣诞节到了,小可可送给小薰一棵圣诞树。这棵圣诞树很奇怪,它是一棵多叉树,有n个点,n-1条边。它的每个结点都有一个权值。小可可和小薰想用这棵树玩一个游戏。
定义(s,e)为树上从s到e的简单路径,我们可以记下在这条路径上经过的结点,定义这个结点序列为S(s,e)。
我们按照如下方法定义这个序列S(s,e)的权值G(S(s,e)):假设这个序列中结点的权值为Z0,Z1,…,Z(L-1),其中L为序列的长度,我们定义G(S(s,e))=Z0 × k^0 + Z1 × k^1 + … + Z(L-1) × k^(L-1)。
如果路径(s,e)满足G(S(s,e)) ≡ x (mod y) ,那么这条路径属于小可可,否则这条路径属于小薰。小可可和小薰很显然不希望这个游戏变得那么简单。小薰认为如果路径(p1,p2)和(p2,p3)都属于他,那么路径(p1,p3)也属于他,反之如果路径(p1,p2)和(p2,p3)都属于小可可,那么路径(p1,p3)也属于小可可。然而这个性质并不总是正确的。所以小薰想知道到底有多少三元组(p1,p2,p3)满足这个性质。
小薰表示她看一眼就知道这道题怎么做了。你会吗?
这题最重要的就是推出一个式子。
我们正难则反。用 n3 减去所有三条边相等的情况。
设in0,out0,in1,out1分别表示一个节点连进来的路径mod y=x的个数,和mod y≠x的个数;和从这个点连出去的路径mod y=x的个数,和mod y≠x的个数。
ans=n3−∑in0[i]∗out0[i]∗2+in0∗out1[i]+in1[i]∗out0[i]+in1[i]∗out1[i]∗22
这个东西随便退一下就好了。
然后我们知道in0[i]=n-in1[i],out0[i]=n-out1[i]
那么只用处理处in0和out0就可以了。
这个明显是用树分治来处理。
那么我们现在已知i的儿子(j,i)和(i,k)的值,我们要求(j,k),设(j,i)=a,(j,i)的长度为len,(i,k)为b。
因为要在mod y意义下=x,所以 a+b∗klen=x ,那么转化一下 b=x−aklen
那么我们处理出a的数组,然后排序,求出 x−aklen ,然后在排序后的b数组中二分查找,求出关于a的in,那么对应b的区间[l,r]的out都要加一,那么在l上打一个+1表示,r+1上打一个-1标记。
还要处理一个问题,就是从j节点上来,然后又从j节点下去,那么直接把j的这棵子树的a和b单独做一下然后减掉就好了。
处理重心就不用但是时间了。
#include
#include
#include
#include
#include
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fod(i,a,b) for(i=a;i>=b;i--)
#define rep(i,a) for(i=first[a];i;i=next[i])
using namespace std;
typedef long long ll;
const int maxn=100007;
int i,j,k,l,t,m,x,y,mo,tot,tot1;
int first[maxn*2],last[maxn*2],next[maxn*2],num,zhi[maxn];
int f[maxn][21],deep[maxn],size[maxn],pan,z;
int c[maxn];
ll ni[maxn],ci[maxn],in[maxn],out[maxn],ans,ans1,n;
bool az[maxn];
struct node{
int o,len;
ll a;
}a[maxn],b[maxn],d[maxn],e[maxn];
bool cmp(node x,node y){return x.a<y.a;}
void add(int x,int y){last[++num]=y,next[num]=first[x],first[x]=num;}
ll qsm(ll x,ll y){
ll z=1;
for(;y;y/=2,x=x*x%mo)if(y&1)z=z*x%mo;
return z;
}
void gsize(int x,int y){
int i;
size[x]=1;
rep(i,x){
if(last[i]!=y&&!az[last[i]]){
gsize(last[i],x);
size[x]+=size[last[i]];
}
}
}
void zhaozhong(int x,int y){
int i;bool u=1;
rep(i,x){
if(last[i]!=y&&!az[last[i]]){
zhaozhong(last[i],x);
if(size[last[i]]>pan/2)u=0;
}
}
if(pan-size[x]>pan/2)u=0;
if(u)z=x;
}
void dfs(int x,int y,ll z,int dep,ll u){
int i;
z=(z+zhi[x]*ci[dep-2])%mo;tot++;tot1++;
a[tot].a=(u*k+zhi[x])%mo;b[tot].a=z;
a[tot].len=b[tot].len=dep;a[tot].o=b[tot].o=x;
d[tot1].a=(u*k+zhi[x])%mo;e[tot1].a=z;
d[tot1].len=e[tot1].len=dep;d[tot1].o=e[tot1].o=x;
int p=a[tot].a;
rep(i,x){
if(last[i]!=y&&!az[last[i]]){
dfs(last[i],x,z,dep+1,p);
}
}
}
void solve(){
// memset(c,0,sizeof(c));
int i,l,r,mid,o,yi,er;b[0].a=-1,b[tot+1].a=mo+1;
fo(i,1,tot){
o=(x-a[i].a+mo)*ni[a[i].len]%mo;
l=0,r=tot+1;
while(l2;
if(b[mid].a>=o)r=mid;else l=mid+1;
}
yi=l;
l=0,r=tot+1;
while(l1)/2;
if(b[mid].a>o)r=mid-1;else l=mid;
}
er=l;
if(yi>er)continue;
out[a[i].o]+=er-yi+1;
c[yi]++,c[er+1]--;
}
o=0;c[tot+1]=0;c[tot+2]=0;
fo(i,1,tot){
o+=c[i];
in[b[i].o]+=o;
c[i]=0;
}
}
void solve1(){
// memset(c,0,sizeof(c));
int i,l,r,mid,yi,er;e[0].a=-1,e[tot1+1].a=mo+1;
ll o;
fo(i,1,tot1){
o=(x-d[i].a+mo)*ni[d[i].len]%mo;
l=0,r=tot1+1;
while(l2;
if(e[mid].a>=o)r=mid;else l=mid+1;
}
yi=l;
l=0,r=tot1+1;
while(l1)/2;
if(e[mid].a>o)r=mid-1;else l=mid;
}
er=l;
if(yi>er)continue;
out[d[i].o]-=er-yi+1;
c[yi]++,c[er+1]--;
}
o=0;c[tot1+1]=0;c[tot1+2]=0;
fo(i,1,tot1){
o+=c[i];
in[e[i].o]-=o;
c[i]=0;
}
}
void fen(int x,int y){
int i;
gsize(x,y);
pan=size[x];
zhaozhong(x,y);az[z]=1;tot=0;
a[++tot].a=zhi[z]%mo,a[tot].len=1,a[tot].o=z;
b[tot].a=0,b[tot].len=1,b[tot].o=z;
rep(i,z){
if(!az[last[i]]){
tot1=0;
dfs(last[i],z,0,2,zhi[z]);
sort(d+1,d+1+tot1,cmp);
sort(e+1,e+1+tot1,cmp);
solve1();
}
}
sort(a+1,a+1+tot,cmp);
sort(b+1,b+1+tot,cmp);
solve();
rep(i,z){
if(!az[last[i]]){
fen(last[i],z);
}
}
}
int main(){
//freopen("fan.in","r",stdin);
scanf("%d%d%d%d",&n,&mo,&k,&x);ans=n*n*n;
ci[0]=1;fo(i,1,n)ci[i]=ci[i-1]*k%mo,ni[i]=qsm(ci[i],mo-2);
fo(i,1,n)scanf("%d",&zhi[i]);
fo(i,1,n-1){
scanf("%d%d",&l,&t);
add(l,t),add(t,l);
}
fen(1,0);
fo(i,1,n){
ans1+=in[i]*(n-out[i])+in[i]*(n-in[i])*2+(n-out[i])*out[i]*2+(n-out[i])*in[i];
}
ans1=ans1/2;
printf("%lld\n",ans-ans1);
}