这次先写了不用容斥的版本,WA 了好几次;又改成容斥的版本,还是没过。一怒之下把所有的 int 都改成 long long,就过了……
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<map>
#define REP(i,a,b) for(int i=a;i<=b;i++)
#define MS0(a) memset(a,0,sizeof(a))
using namespace std;
typedef long long ll;

/*
 * Counts the paths of a tree (single-vertex paths included) whose product of
 * vertex weights is a perfect cube, using centroid decomposition.
 *
 * toCube() reduces each weight to its exponent vector over the K given
 * primes. Every exponent is taken mod 3, and the vector is packed as a
 * base-3 number (digit i belongs to prime p[i]). A product is a cube exactly
 * when the digit-wise mod-3 sum (add3) of the packed vectors along the path
 * is 0.
 *
 * NOTE(review): toCube() only divides out p[0..K-1]. This assumes every
 * weight factors completely over those primes and is >= 1 (a weight of 0
 * would make toCube() loop forever).
 */

const int maxn=1000100;
const int INF=1e9+10;

ll N,K;                    // vertex count, prime count
ll p[maxn];                // the K primes
ll val[maxn];              // packed base-3 exponent vector of each vertex
ll u,v;                    // edge endpoints read in main()
ll e[maxn],tot;            // edge targets, number of stored (directed) edges
ll first[maxn],Next[maxn]; // adjacency lists; -1 terminates a list
bool vis[maxn];            // vertex has already served as a centroid
ll rt,balance;             // best centroid so far and its largest remaining part
map<ll,ll> id;             // packed vector -> count of centroid paths seen with it
ll d[maxn],dn;             // vertices gathered by dfs_d() and their count
ll s[maxn];                // packed vector of the path centroid..vertex, inclusive

// Resets the adjacency lists before reading a new tree.
void Init()
{
    tot=0;
    memset(first,-1,sizeof(first));
}

// Appends the directed edge u -> v to u's adjacency list.
void addedge(ll u,ll v)
{
    e[++tot]=v;
    Next[tot]=first[u];
    first[u]=tot;
}

// Returns n^k by repeated squaring (no modulus, no overflow check).
ll qpow(ll n,ll k)
{
    ll res=1;
    while(k){
        if(k&1) res*=n;
        n*=n;
        k>>=1;
    }
    return res;
}

// Packs the exponents of p[0..K-1] in x, each taken mod 3, into a base-3
// number whose digit i is (exponent of p[i] in x) mod 3.
ll toCube(ll x)
{
    ll res=0;
    REP(i,0,K-1){
        ll cur=0;
        while(x%p[i]==0){
            cur++;
            x/=p[i];
        }
        res+=qpow(3,i)*(cur%3);
    }
    return res;
}

// Digit-wise (a + b) mod 3 over the low K base-3 digits of a and b.
ll add3(ll a,ll b)
{
    ll c=0,x=0,y=0;
    ll t=1;
    REP(i,0,K-1){
        x=a%3;a/=3;
        y=b%3;b/=3;
        c+=((x+y)%3)*t;
        t*=3;
    }
    return c;
}

// Digit-wise (a - b) mod 3 over the low K base-3 digits of a and b.
ll cut3(ll a,ll b)
{
    ll c=0,x=0,y=0;
    ll t=1;
    REP(i,0,K-1){
        x=a%3;a/=3;
        y=b%3;b/=3;
        c+=((x-y+3)%3)*t;
        t*=3;
    }
    return c;
}

// Collects every vertex reachable from u (not crossing parent f or removed
// centroids) into d[1..dn]. For each one, it records in s[] the packed vector
// of the path from the current centroid down to that vertex. dep is that
// vector for u itself.
void dfs_d(ll u,ll f,ll dep)
{
    d[++dn]=u;
    s[u]=dep;
    for(int i=first[u];~i;i=Next[i]){
        int v=e[i];
        if(v==f||vis[v]) continue;
        dfs_d(v,u,add3(dep,val[v]));
    }
}

// Returns the number of vertices below u (parent f, removed centroids
// excluded). Along the way it updates rt/balance with the vertex whose
// largest remaining part is smallest, given a component of sz vertices.
ll get_rt(ll u,ll f,int sz)
{
    ll cnt=1,balance1=0;
    for(int i=first[u];~i;i=Next[i]){
        int v=e[i];
        if(v==f||vis[v]) continue;
        ll tmp=get_rt(v,u,sz);
        cnt+=tmp;
        balance1=max(balance1,tmp);
    }
    balance1=max(balance1,sz-cnt);   // the part "above" u
    if(balance1<balance){
        balance=balance1;
        rt=u;
    }
    return cnt;
}

// Counts the cube paths in the component containing u (vertices already
// marked in vis[] are excluded). It then recurses into the sub-components
// left after removing that component's centroid.
ll solve(int u)
{
    // The first pass only measures the component size (N is just an upper
    // bound). The second pass, with the exact size, locates the centroid.
    rt=u;balance=INF;
    ll sz=get_rt(u,0,N);
    rt=u;balance=INF;
    get_rt(u,0,sz);
    u=rt;
    vis[u]=1;

    ll res=0;
    id.clear();
    // s[u]=val[u] is the path consisting of the centroid alone. Registering
    // it lets paths that end at u be matched below.
    s[u]=val[u];
    id[s[u]]++;
    if(val[u]==0) res++;   // single-vertex path {u}

    for(ll i=first[u];~i;i=Next[i]){
        int v=e[i];
        if(vis[v]) continue;
        dn=0;
        dfs_d(v,u,add3(val[u],val[v]));
        // Pair each vertex x of this subtree with a path end y seen earlier
        // (in previous subtrees, or u itself). val[u] is included in both
        // s[x] and s[y], so x..y is a cube iff s[y] == val[u] - s[x]
        // digit-wise.
        REP(j,1,dn){
            ll x=s[d[j]];
            ll y=cut3(val[u],x);
            // find() rather than id[y]: a miss must not insert a zero entry.
            map<ll,ll>::iterator it=id.find(y);
            if(it!=id.end()) res+=it->second;
        }
        // Register this subtree only afterwards, so pairs lying entirely
        // inside it (which do not pass through u) are not counted here.
        REP(j,1,dn) id[s[d[j]]]++;
    }
    for(int i=first[u];~i;i=Next[i]){
        int v=e[i];
        if(vis[v]) continue;
        res+=solve(v);
    }
    return res;
}

int main()
{
    //freopen("in.txt","r",stdin);
    // Every value read or printed here is a long long, so %lld is the matching
    // standard conversion. Passing a long long* to %d is undefined behaviour,
    // and %I64d is a Microsoft-only spelling that other C libraries (e.g.
    // glibc) do not treat as a 64-bit conversion.
    while(scanf("%lld%lld",&N,&K)==2){
        REP(i,0,K-1) scanf("%lld",&p[i]);
        REP(i,1,N){
            ll x;
            scanf("%lld",&x);
            val[i]=toCube(x);
        }
        Init();
        REP(i,1,N-1){
            scanf("%lld%lld",&u,&v);
            addedge(u,v);
            addedge(v,u);
        }
        MS0(vis);
        printf("%lld\n",solve(1));
    }
    return 0;
}
/**
Sample input (test cases separated by blank lines):
5 3
2 3 5
2500 200 9 270000 27
4 2
3 5
2 5
4 1

2 2
3 5
9 3
1 2

6 3
2 3 5
10 10 10 10 10 10
1 2
2 3
3 4
4 5
5 6

6 3
2 3 5
216 10 10 10 25 5
1 2
2 3
3 4
4 5
5 6

12 3
2 3 5
1 3 5 3 3 9 1 5 5 2 4 2
1 4
4 2
4 5
4 6
4 7
2 3
5 8
5 9
6 10
6 11
7 12

5 3
2 3 5
1 1 1 1 1
1 2
1 3
1 4
1 5
*/