给出一个数组A,经过一次处理,生成一个数组S,数组S中的每个值相当于数组A的累加,比如:A = {1 3 5 6} => S = {1 4 9 15}。如果对生成的数组S再进行一次累加操作,{1 4 9 15} => {1 5 14 29},现在给出数组A,问进行K次操作后的结果。(输出结果 Mod 10^9 + 7)
2 <= n <= 50000, 0 <= k <= 10^9, 0 <= a[i] <= 10^9
首先有个结论就是 ans[n]=∑ni=0Cii+k−1∗a[n−i] a n s [ n ] = ∑ i = 0 n C i + k − 1 i ∗ a [ n − i ] 。
这条式子是怎么来的呢?我的理解就是,把每一次前缀和后的数组放在一起,初始数组在第0行,前缀和放在第一行,如此类推。那么a[i]对a[n]贡献的系数就相当于每次可以从(x,y)走到(x+1,y+l)(l>=0),从(0,i)走到(k,n)的不同方案数。根据隔板法不难得到系数就是 Cn−in−i+k−1 C n − i + k − 1 n − i 。
上面的式子是一个卷积的形式,可以用FFT来优化。但注意到这里的模数并不是NTT模数,那么就要用到任意模数FFT。
具体来讲就是取 M=P−−√ M = P ,设 f(x)=k(x)∗M+r(x) f ( x ) = k ( x ) ∗ M + r ( x ) ,那么有
f1(x)∗f2(x)=(k1(x)∗M+r1(x))∗(k2(x)∗M+r2(x)) f 1 ( x ) ∗ f 2 ( x ) = ( k 1 ( x ) ∗ M + r 1 ( x ) ) ∗ ( k 2 ( x ) ∗ M + r 2 ( x ) )
=M2∗k1(x)∗k2(x)+M∗(k1(x)∗r2(x)+k2(x)∗r1(x))+r1(x)∗r2(x) = M 2 ∗ k 1 ( x ) ∗ k 2 ( x ) + M ∗ ( k 1 ( x ) ∗ r 2 ( x ) + k 2 ( x ) ∗ r 1 ( x ) ) + r 1 ( x ) ∗ r 2 ( x )
总共需要做7次FFT且卷积后元素的数量级为O(nP),可以用long double来进行FFT而不是高精度。
一开始打完后被卡了精度,后面把pi从define变成const就过了。。。
第一次手打复数类型,发现比自带的要快将近一倍。。。
#include
#include
#include
#include
#include
#include
using namespace std;
typedef long long LL;
const int MOD=1000000007;
const int N=50005;
const long double pi=acos((long double)-1);
struct com
{
long double x,y;
com operator + (const com &d) {return (com){x+d.x,y+d.y};}
com operator - (const com &d) {return (com){x-d.x,y-d.y};}
com operator * (const com &d) {return (com){x*d.x-y*d.y,x*d.y+d.x*y};}
com operator / (const long double &d) {return (com){x/d,y/d};}
};
int n,M,ny[N],c[N],a[N],k,L,rev[N*4];
com s1[N*4],s2[N*4],s3[N*4],k1[N*4],r1[N*4],k2[N*4],r2[N*4];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void init()
{
n=read();k=read();
for (int i=0;i0]=ny[1]=c[0]=1;c[1]=k;
for (int i=2;i1]*(k+i-1)%MOD;
for (int i=2;i1]%MOD,c[i]=(LL)c[i]*ny[i]%MOD;
}
void fft(com *a,int f)
{
for (int i=0;iif (ifor (int i=1;i1)
{
com wn=(com){cos(pi/i),f*sin(pi/i)};
for (int j=0;j1))
{
com w=(com){1,0};
for (int k=0;kif (f==-1) for (int i=0;iint main()
{
init();M=sqrt(MOD);
int lg=0;
for (L=1;L<=n*2;L<<=1,lg++);
for (int i=0;i>1]>>1)|((i&1)<<(lg-1));
for (int i=0;i0.0},r1[i]=(com){a[i]%M,0.0},k2[i]=(com){c[i]/M,0.0},r2[i]=(com){c[i]%M,0.0};
fft(k1,1);fft(r1,1);fft(k2,1);fft(r2,1);
for (int i=0;i1);fft(s2,-1);fft(s3,-1);
for (int i=0;iint x1=(LL)(s1[i].x+0.5)%MOD*M*M%MOD,x2=(LL)(s2[i].x+0.5)%MOD*M%MOD,x3=(LL)(s3[i].x+0.5)%MOD;
int ans=x1+x2;ans-=ans>=MOD?MOD:0;ans+=x3;ans-=ans>=MOD?MOD:0;
printf("%d\n",ans);
}
return 0;
}