【题目描述】一个2*i的矩阵,一共有m种颜色,相邻两个格子颜色不能相同,m种颜色不必都用上,f[i]表示这个答案,求Σf[i]*(2*i)^m (1<=i<=n)%p。
【数据范围】
20% n,m<10^5 p<10^9
其余 n<10^9
其中40% m<100 p<10^9
20% m<10^3 p<10^9
20% m<10^4 p<10^3
首先我们可以推导出f[i]的式子,f[1]=m*(m-1),因为其余的格子,我们第一个格子可以放m-1种颜色,第二个格子在第一个格子和上一层第二个格子颜色不同时有m-2种情况,相同时有m-1种情况,那么我们可以得出f[i]=f[i-1]*(m^2-3*m+3),设a=m^2-3*m+3,那么问题就变成了求2^m*m*(m-1)/a*Σa^i*i^m (1<=i<=n)。
对于前20%的数据,我们可以暴力的nlogm递推求解。
对于中间的20%数据,我们可以化简一下求解式。
设w[i][j]为Σa^j*j^k (1<=j<=i),那么答案就成了2^m*m*(m-1)/a*w[n][m],现在的问题就是求解w[n][m]。
w[n][m]=Σa^i*i^m (1<=i<=n),w[n+1][m]=Σa^i*i^m (1<=i<=n+1)=a+Σa^i*i^m (2<=i<=n+1)=a+a*Σa^i*(i+1)^m (1<=i<=n)
那么我们可以根据二项式展开来化简,w[n+1][m]=a+a*Σa^iΣc(m,j)*i^j (1<=i<=n) (0<=j<=m)=a+a*Σc(m,j)*w[n][j] (0<=j<=m),那么这样我们就可以写一个矩阵来加速在n^3logn的时间内求解了。
对于m<10^3的情况,w[n+1][m]=w[n][m]+a^(n+1)*(i+1)^m。根据刚才的推导,w[n+1][m]=a+a*Σc(m,j)*w[n][j] (0<=j<=m)。所以我们可以得到
w[n][m]+a^(n+1)*(i+1)^m=a+a*Σc(m,j)*w[n][j] (0<=j<=m),这样我们发现,如果我们知道了w[n][m]之前的w[n][j],那么我们可以在m的时间内反解出w[n][j]。
对于最后的20%,我们发现p非常小,考虑答案的求和式Σa^i*i^m,发现这个式子之和i和i^m有关,当i,i^m和j,j^m关于mod p相等之后,那么i与j之后的变换是相同的,我们可以发现一共有p^2个不同的情况,那么我们记下来这个循环之后就可以算出来了。
//By BLADEVIL #include <cstdio> #define LL long long #define maxp 1010 #define maxm 1010 using namespace std; int n,m,p,a; int mo[maxp],next[maxp],w[maxm],c[maxm][maxm]; int flag[maxp][maxp],f[maxp*maxp],g[maxp*maxp]; int mi(int a,int k) { int ans=1; while (k) { if (k&1) ans=((LL)ans*a)%p; a=((LL)a*a)%p; k>>=1; } return ans; } void work1() { int ans=0; ans=((LL)m*(m-1))%p; ans=((LL)ans*mi(2,m))%p; ans=((LL)ans*mi(a,p-2))%p; //printf("%d\n",ans); int cur=0; for (int i=1;i<=n;i++) cur=((LL)cur+(LL)mi(a,i)*mi(i,m))%p; ans=((LL)ans*cur)%p; printf("%d\n",ans); } void work2() { int i,l,r; f[1]=(LL)m*(m-1)%p; g[1]=(LL)f[1]*mi(2,m)%p; flag[1][f[1]]=1; for (i=2;i<=n;i++) { f[i]=(LL)f[i-1]*a%p; g[i]=(LL)f[i]*mi(2*i,m)%p; if (flag[i%p][f[i]]) { l=flag[i%p][f[i]]; r=i-1; break; } flag[i%p][f[i]]=i; } //printf("%d %d %d %d\n",l,r,i,g[i]); int len=r-l+1,ans=0,tmp=0; for (int j=1;j<l;j++) ans=(ans+g[j])%p; n-=l-1; for (int j=l;j<=r;j++) tmp=(tmp+g[j])%p; ans=(ans+(LL)tmp*(n/len)%p)%p; for (int j=1;j<=n%len;j++) ans=(ans+g[l+j-1])%p; printf("%d\n",ans); } void work3() { for (int i=1;i<=m;i++) { c[i][0]=c[i][i]=1; for (int j=1;j<i;j++) c[i][j]=(c[i-1][j]+c[i-1][j-1])%p; } int ans=0; ans=((LL)m*(m-1))%p; ans=((LL)ans*mi(2,m))%p; ans=((LL)ans*mi(a,p-2))%p; //printf("%d\n",ans); //for (int i=1;i<=m;i++) printf("%d ",c[m][i]); printf("\n"); w[0]=((LL)mi(a,n+1)-a+p)%p; w[0]=((LL)w[0]*mi(a-1,p-2))%p; //printf("%d\n",w[0]); for (int i=1;i<=m;i++) { w[i]=((LL)mi(a,n+1)*mi(n+1,i)-a+p)%p; for (int j=0;j<i;j++) w[i]=((LL)w[i]-((LL)a*c[i][j]%p*w[j]%p)+p)%p; w[i]=(LL)w[i]*mi(a-1,p-2)%p; w[i]%=p; } //printf("%d\n",w[m]); ans=((LL)ans*w[m])%p; printf("%d\n",ans); } int main() { freopen("color.in","r",stdin); freopen("color.out","w",stdout); scanf("%d%d%d",&n,&m,&p); a=((LL)m*m-3*m+3)%p; if (n<=100000) work1(); else if (m>1000) work2(); else work3(); fclose(stdin); fclose(stdout); return 0; }