今早重新看了myy的论文,又掌握了一些多项式乘法的新姿势,于是写一篇blog巩固一下QAQ。
①如何用一次DFT加一次IDFT求出两个实序列A和B的卷积?
这里我们只要求卷积后的结果,不需要求DFT的值,所以有一种很简便的方法:令复数序列C的实部为A,虚部为B。将其自卷,所得结果虚部的值除以2就是要求的多项式。
这个十分容易证明:
②如何用一次DFT同时求出两个实序列在单位复数根处的点值?
这个推导就很复杂了,大概就是一堆三角函数和 i i 换来换去,具体要看myy的论文,我也不再赘述。最终概括一下做法,就是设:
令 P(x),Q(x) P ( x ) , Q ( x ) 在单位复数根处的点值表达分别为 FP,FQ F P , F Q ,则可以证明 FP(k) F P ( k ) 与 FQ(N−k) F Q ( N − k ) 互为共轭复数。因此只需要对 P(x) P ( x ) 进行DFT即可。然后会有:
(myy论文里第二条式子写的是乘以 i i ,不过我觉得是除以 i i 才对)
③关于单位复数根 ωkN ω N k :
这个有三种算法。一种是累乘,时间快但精度不高。还有一种是直接每次用 2πkN 2 π k N 的三角函数去算,这样比较慢,但精度高。最后一种是预处理,然后用vector之类的存下来。第三种方法比较折中,在拆系数+FFT处理任意模数多项式卷积的时候经常用。
贴一份模板题(洛谷P3803)的CODE:
因为只用了两次DFT,所以完全不怕卡常
#include
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
const int maxn=4000000;
const double pi=acos(-1.0);
struct Complex
{
double X,Y;
Complex (double a=0.0,double b=0.0) : X(a),Y(b) {}
} ;
Complex operator+(Complex a,Complex b){return Complex(a.X+b.X,a.Y+b.Y);}
Complex operator-(Complex a,Complex b){return Complex(a.X-b.X,a.Y-b.Y);}
Complex operator*(Complex a,Complex b){return Complex(a.X*b.X-a.Y*b.Y,a.X*b.Y+a.Y*b.X);}
Complex A[maxn];
Complex B[maxn];
vector w[maxn];
int Rev[maxn];
int N,Lg;
int F[maxn];
int G[maxn];
int n,m;
void DFT(Complex *a,double f)
{
for (int i=0; iif (ifor (int len=2; len<=N; len<<=1)
{
int mid=(len>>1);
for (Complex *p=a; p!=a+N; p+=len)
for (int i=0; iif (f==-1.0) temp.Y=-temp.Y;
temp=temp*p[mid+i];
p[mid+i]=p[i]-temp;
p[i]=p[i]+temp;
}
}
}
void FFT()
{
N=1,Lg=0;
while (N4) N<<=1,Lg++;
for (int i=0; ifor (int j=0; jif (i&(1<1<<(Lg-j-1));
int len=1;
while ((len<<1)<=N)
{
double ang=pi/len;
for (int i=0; icos(ang*(double)i) , sin(ang*(double)i) ) );
len<<=1;
}
for (int i=0; idouble)F[i],(double)G[i]);
DFT(A,1.0);
for (int i=0; ifor (int i=0; i2.0;
a.Y/=2.0;
B[i]=A[i]-B[i];
B[i].X/=2.0;
B[i].Y/=2.0;
swap(B[i].X,B[i].Y);
B[i].Y=-B[i].Y;
A[i]=a;
}
for (int i=0; i1.0);
for (int i=0; idouble)N);
for (int i=0; iint)floor( A[i].X+0.5 );
}
int main()
{
freopen("3803.in","r",stdin);
freopen("3803.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=0; i<=n; i++) scanf("%d",&F[i]);
for (int i=0; i<=m; i++) scanf("%d",&G[i]);
FFT();
for (int i=0; i<=n+m; i++) printf("%d ",F[i]);
printf("\n");
return 0;
}