给定一颗n个结点的无根树,树上的每个点有一个非负整数点权,定义一条路径的价值为路径上的点权和-路径的点权最大值。
给定参数p,我们想知道,有多少不同的树上简单路径,满足它的价值恰好是p的倍数。
注意:单点算作一个路径;u ≠ v时,(u,v)和(v,u)只算一次。
首先套上树分治模板,再想想怎么做……
可以发现,若是想满足值为p的倍数,那就意味着需要该值mod p=0,问题来了,如何统计一个子树的答案呢?其实很容易解决,可以发现,当一条链的贡献是x时,它需要另一条贡献为(p-x)的链结合(在模p意义下),那么我们可以开一个桶来储存,对于减去的最大值,我们只需要把每一条链按最大值排个序,就能够保证每一次减去的最大值是当前链的最大值。还有要记得去重的问题,因为每一棵子树都会被算多,在做完当前树之后要减去每棵子树的贡献。时间复杂度为(n log n)。
#include
#include
#include
#include
#include
using namespace std;
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define rep(i,x) for(i=la[x];i;i=ne[i])
typedef long long ll;
const int N=1e5+5,MX=1e7+5;
struct arr{
ll s;int mx;
}t[N];
int la[N],ne[N*2],da[N*2],node[N],T[MX],D[N],fa[N],a[N];
int n,mo,i,x,y,sum,num,hea;
ll ans;
bool p[N],bz[N];
void ins(int x,int y){
da[++sum]=y,ne[sum]=la[x],la[x]=sum;
da[++sum]=x,ne[sum]=la[y],la[y]=sum;
}
bool cmp(arr x,arr y){return x.mxvoid gnode(int x){
int i,l=0,r=1;D[1]=x;fa[x]=0;
while(lif(da[i]!=fa[x]&&!p[da[i]]) D[++r]=da[i],fa[da[i]]=x;
}l++;
while(l>1){
x=D[--l];node[x]=1;
rep(i,x) if(da[i]!=fa[x]&&!p[da[i]]) node[x]+=node[da[i]];
}
}
void gheav(int x,int size){
int i,l=0,r=1;bz[x]=1;D[1]=x;fa[x]=0;
while(lif(da[i]!=fa[x]&&!p[da[i]]){
if(node[da[i]]>size/2) bz[x]=0;
D[++r]=da[i];bz[da[i]]=1;fa[da[i]]=x;
}
}l++;
while(l>1){
x=D[--l];if(bz[x]&&size-node[x]<=size/2) hea=x;
}
}
void dfs(int x,int fa,ll sum,int mx){
t[++num]=(arr){sum,mx};
int i;rep(i,x) if(da[i]!=fa&&!p[da[i]]) dfs(da[i],x,sum+a[da[i]],max(mx,a[da[i]]));
}
void deal(int tot,int x,ll sum,int mx,int zf){
num=0;dfs(x,0,sum,mx);
sort(t+1,t+num+1,cmp);
fo(i,1,num){
ans+=T[(t[i].s-t[i].mx)%mo]*zf;
T[((mo-t[i].s+tot)%mo+mo)%mo]++;
}
fo(i,1,num)T[((mo-t[i].s+tot)%mo+mo)%mo]--;
}
void divi(int x){
int i;
gnode(x);hea=0;gheav(x,node[x]);x=hea;
p[x]=1;deal(a[x],x,a[x],a[x],1);
rep(i,x) if(!p[da[i]]){
deal(a[x],da[i],a[x]+a[da[i]],max(a[x],a[da[i]]),-1);
divi(da[i]);
}
}
int main(){
freopen("path.in","r",stdin);
freopen("path.out","w",stdout);
scanf("%d%d",&n,&mo);
fo(i,2,n) scanf("%d%d",&x,&y),ins(x,y);
fo(i,1,n) scanf("%d",&a[i]);
divi(1);ans+=n;
printf("%lld",ans);
}