对快速傅里叶变换(FFT)的缺点进行了优化。
在计算多项式乘法(卷积)时,FFT设计三角函数、复数等很多恶心的东西,有着最大的缺点:精度问题,而在很多题目中往往需要进行取模,要求精度很高,FFT就不行了。
于是就有了快速数论变换
FFT之所以可以实现,是利用了单位复根 ω \omega ω的周期性质, ω n n = 1 , ω n k = ω n k + n \omega_n^n=1,\omega_n^k=\omega_n^{k+n} ωnn=1,ωnk=ωnk+n;
通过这个性质,可以把FFT后续所有步骤全部推导出来。
NTT由于需要取模,根据模数,我们可以重新定义一个类似于单位复根的东西,使它的幂有周期性,那就是原根
原根: ω n n ≡ 1 ( m o d p ) \omega_n^n\equiv 1(mod\space p) ωnn≡1(mod p),且没有 ( k = 1 , 2 , 3... , n − 1 ) (k=1,2,3...,n-1) (k=1,2,3...,n−1) ω n k ≡ 1 ( m o d p ) \omega_n^k\equiv 1(mod\space p) ωnk≡1(mod p)。
对于每个质数 p p p,令 g p − 1 ≡ 1 ( m o d p ) g^{p-1}\equiv 1(mod\space p) gp−1≡1(mod p),且 g k m o d p g^k\space mod\space p gk mod p都不为1, ( 1 ≤ k ≤ p − 2 ) (1≤k≤p-2) (1≤k≤p−2)
则 g p − 1 n g^{\frac {p-1} n} gnp−1就可以作为原根 ω n \omega_n ωn,满足FFT中单位复根的一切性质。
将FFT中所有单位复根换位原根,就可以实现NTT了。
//UOJ34
#include
#include
using namespace std;
const int MAXN=400005,MOD=998244353,G=3;
int PowMod(int a,int b)
{
int res=1;
for(;b;b>>=1,a=1LL*a*a%MOD)
if(b&1)
res=1LL*res*a%MOD;
return res;
}
void NTT(int A[],int n,int mode)
{
for(int i=0,j=0;i<n;i++)
{
if(i<j)swap(A[i],A[j]);
int k=n>>1;
while(k&j)
j^=k,k>>=1;
j^=k;
}
for(int i=1;i<n;i<<=1)
{
int w1=PowMod(ROOT,(MOD-1)/(i<<1));
if(mode==-1)
w1=PowMod(w1,MOD-2);
for(int j=0;j<n;j+=(i<<1))
for(int l=j,r=j+i,w=1;l<j+i;l++,r++,w=1LL*w*w1%MOD)
{
int tmp=1LL*A[r]*w%MOD;
A[r]=(A[l]-tmp+MOD)%MOD;
A[l]=(A[l]+tmp)%MOD;
}
}
if(mode==-1)
{
int invn=PowMod(n,MOD-2);
for(int i=0;i<n;i++)
A[i]=1LL*A[i]*invn%MOD;
}
}
void Multiply(const int A[],int len1,const int B[],int len2,int C[])
{
static int A0[MAXN*3],B0[MAXN*3];
int len=1;
for(;len<len1+len2-1;len<<=1);
for(int i=0;i<len1;i++)A0[i]=A[i];
for(int i=len1;i<len;i++)A0[i]=0;
for(int i=0;i<len2;i++)B0[i]=B[i];
for(int i=len2;i<len;i++)B0[i]=0;
NTT(A0,len,1);NTT(B0,len,1);
for(int i=0;i<len;i++)
A0[i]=1LL*A0[i]*B0[i]%MOD;
NTT(A0,len,-1);
for(int i=0;i<len1+len2-1;i++)C[i]=A0[i];
}
int A[MAXN],B[MAXN];
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)
scanf("%d",A+i);
for(int i=0;i<=m;i++)
scanf("%d",B+i);
Multiply(A,n+1,B,m+1,A);
for(int i=0;i<n+m;i++)
printf("%d ",A[i]);
printf("%d\n",A[n+m]);
return 0;
}