题目大意:给一棵N个节点的树及正整数K,对每一个节点i求Σdist(i,j)^K。N<=5*10^4,K<=150。
O(NK^2),O(NKlogK)的做法都可以在贾志鹏2011年的集训队作业里看到。下面介绍一种O((N+K)K)的做法。
首先讲讲第二类斯特林数。S(N,K)表示将N个元素划分成K个非空子集的方案数。S(N,K)可以表示成递推形式:S(N,K)=S(N-1,K-1)+K*S(N-1,K)。N<=5的斯特林数大概长这样:
1
1 1
1 3 1
1 7 6 1
1 15 25 10 1
记[x]_i=x*(x-1)*(x-2)*...*(x-i+1)
X^1=1*[x]_1
X^2=1*[x]_2 + 1*[x]_1
X^3=1*[x]_3 + 3*[x]_2 + 1*[x]_1
X^4=1*[x]_4 + 6*[x]_3 + 7*[x]_2 + 1*[x]_1
....
可以看出x^n=ΣS(N,k)*[x]_k。具体证明可以看《Concrete Mathematics》中关于Stirling数的讲解。
接着,我们有了以上结论就要想办法用上。直观的想法是对于每个点i求出f[i][j]=Σ[dist(i,x)]_j,(1<=x<=n,j<=K),那么ans[i]=ΣS(K,j)*f[i][j]。但是这个难以维护的原因在于我们难以从长度L-1求出长度L的f值。
试图改变一下所要求的f[i][j]。注意到[x]_k和组合数的关系,C(x,k)=[x]_k/k!,而且组合数有个优良的性质C(x,k)=C(x-1,k-1)+C(x-1,k)。这样我们记f[i][j]=ΣC(dist(i,x),j),(1<=x<=n,j<=K),就可以在长度之间转移了。
至此我们只需简单利用树形DP(两边DFS)求出f[i][j],然后ans[i]=ΣS(K,j)*f[i][j]*j!。时间和空间复杂度都是O((N+K)K)。
code:
#include <iostream> #include <cstdio> #include <cstring> #include <cstdlib> #include <algorithm> #include <cmath> using namespace std; const int maxn=50003,maxk=155,mod=10007; int n,K,Link[maxn],pre[maxn*2],t[maxn*2],f[maxn][maxk],S[maxk][maxk],ans[maxn],g[maxk],h[maxk],fac[maxn]; inline void add(int &x,int y) { if(y>=mod)y-=mod; x+=y; if(x>=mod)x-=mod; } inline void sub(int &x,int y) { if(y>=mod)y-=mod; x-=y; if(x<0)x+=mod; } void dfs1(int x,int pa) { f[x][0]=1; for(int i=Link[x];i;i=pre[i]) { if(t[i]==pa)continue; dfs1(t[i],x); add(f[x][0],f[t[i]][0]); for(int j=1;j<=K;j++)add(f[x][j],f[t[i]][j-1]+f[t[i]][j]); } } void dfs2(int x,int pa) { for(int i=Link[x];i;i=pre[i]) { if(t[i]==pa)continue; int y=t[i]; g[0]=f[x][0],sub(g[0],f[y][0]); h[0]=n; for(int i=1;i<=K;i++) { g[i]=f[x][i],sub(g[i],f[y][i]+f[y][i-1]); h[i]=(f[y][i]+g[i]+g[i-1])%mod; } for(int i=0;i<=K;i++) { f[y][i]=h[i]; add(ans[y],S[K][i]*fac[i]%mod*h[i]%mod); } dfs2(t[i],x); } } int main() { int now,A,B,Q,L; scanf("%d%d%d",&n,&K,&L); scanf("%d%d%d%d",&now,&A,&B,&Q); for(int i=1,tot=0;i<n;i++) { now=(now*A+B)%Q; int x=i-now%((i<L)?i:L),y=i+1; pre[++tot]=Link[x]; Link[x]=tot; t[tot]=y; pre[++tot]=Link[y]; Link[y]=tot; t[tot]=x; } S[0][0]=fac[0]=1; for(int i=1;i<=K;fac[i]=fac[i-1]*i%mod,i++) for(int j=1;j<=i;j++)add(S[i][j],(S[i-1][j-1]+j*S[i-1][j])%mod); dfs1(1,0); for(int i=0;i<=K;i++)add(ans[1],S[K][i]*fac[i]%mod*f[1][i]%mod); dfs2(1,0); for(int i=1;i<=n;i++)printf("%d\n",ans[i]); return 0; }