把正整数M分解成至多N份且每份不为0(注意1+2+3与2+3+1是不一样的即存在顺序性),一份x的价值是f(x)=a2*x*x+a1*x+a0,总价值为每一份价值的乘积。求所有情况下总价值的和,答案模mo。
M<=10000,mo<=255,N<=10^8,a2<=4,a1<=300,a0<=100
设g[i,j]表示把j分解成i份的总价值和。
显然 g[i,j]=∑j−1k=1g[i−1,j−k]∗f[k]
发现这是标准卷积形式,也就是说
g[i+1]=g[i]∗f
假如我们只需要分成N份,那么只需要把g快速幂即可,但是有一个“至多”。
我们设 p[i,j]=∑ik=1g[k,j]
因为g要倍增,那我们可不可以让p也倍增?
事实证明可以。
p[i,j]=p[i/2,j]+∑i/2k=1g[i/2+k,j]
p[i,j]=p[i/2,j]+∑i/2k=1∑j−1l=1g[k,l]∗g[i/2,j−l]
p[i,j]=p[i/2,j]+∑j−1l=1∑i/2k=1g[k,l]∗g[i/2,j−l]
p[i,j]=p[i/2,j]+∑j−1l=1p[i/2,l]∗g[i/2,j−l]
后面那部分是标准卷积形式!
#include<cstdio>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
typedef double db;
struct node{
db x,y;
node friend operator +(node a,node b){
node c;
c.x=a.x+b.x;c.y=a.y+b.y;
return c;
}
node friend operator -(node a,node b){
node c;
c.x=a.x-b.x;c.y=a.y-b.y;
return c;
}
node friend operator *(node a,node b){
node c;
c.x=a.x*b.x-a.y*b.y;c.y=a.x*b.y+a.y*b.x;
return c;
}
};
const db pi=acos(-1);
const int maxm=10000*4;
int b[maxm],g[maxm],h[maxm],p[maxm];
node e[maxm],f[maxm],tt[maxm];
db ce;
int i,j,k,l,t,n,m,mo,len,a2,a1,a0;
void DFT(node *a,int sig){
fo(i,0,len-1){
int p=0;
for(int j=0,tp=i;j<ce;j++,tp/=2) p=(p<<1)+(tp%2);
tt[p]=a[i];
}
for(int m=2;m<=len;m*=2){
int half=m/2; fo(i,0,half-1){ node w; w.x=cos(i*sig*pi/half),w.y=sin(i*sig*pi/half);
for(int j=i;j<len;j+=m){
node u=tt[j],v=tt[j+half]*w;
tt[j]=u+v;
tt[j+half]=u-v;
}
}
}
if (sig==-1)
fo(i,0,len-1) tt[i].x/=len;
fo(i,0,len-1) a[i]=tt[i];
}
void FFT(int *a,int *b,int *c){
int i;
fo(i,0,len-1) e[i].x=a[i],f[i].x=b[i],e[i].y=f[i].y=0;
DFT(e,1);DFT(f,1);
fo(i,0,len-1) e[i]=e[i]*f[i];
DFT(e,-1);
fo(i,0,m) c[i]=round(e[i].x);
fo(i,0,m) c[i]%=mo;
}
void solve(int n){
int i;
if (n==1){
fo(i,0,m) p[i]=g[i]=b[i];
return;
}
solve(n/2);
FFT(p,g,h);
fo(i,0,m) p[i]=(p[i]+h[i])%mo;
FFT(g,g,g);
if (n%2){
FFT(g,b,g);
fo(i,0,m) p[i]=(p[i]+g[i])%mo;
}
}
int main(){
scanf("%d%d%d%d%d%d",&m,&mo,&n,&a2,&a1,&a0);
len=1;
while (len<m*2) len*=2;
ce=db(log(len)/log(2));
b[0]=0;
fo(i,1,m) b[i]=(a2*i%mo*i%mo+a1*i%mo+a0%mo)%mo;
solve(n);
printf("%d\n",p[m]);
}