FFT



传送门:http://www.wikioi.com/problem/3123/

FFT,快速傅里叶变换,蒟蒻看别人的题解都太深奥,看不懂,好不容易学会,以蒟蒻的理解写给那些想学FFT却又找不到合适的资料的OIer,蒟蒻理解有限,难免有许多错误,请大家多多包涵。


快速傅里叶变换

百度的各种讲解都TM扯什么频率什么的,蒟蒻完全看不懂,后来认真看了看算导,获益匪浅,算导上讲的真心不赖,有很多内容都来自算导。

1.多项式

        多项式的两种表达方式:系数表达和点值表达

系数表达就是大家常用的表达方式,点值表达就像在这个多项式函数上取n个不同的点,这样就可以确定原多项式。

比如说二次函数需要3个点就可以确定,一次函数需要2个点,一个n次多项式需要n个点(n次多项式意思是有0..n-1次幂的多项式)

A(x)=x^2+2*x-1可以被表达为{  ( 0 , -1 ) , ( 1 , 2 ) , ( 2 , 7 )  }

FFT_第1张图片FFT_第2张图片

加法和乘法:

         FFT_第3张图片FFT_第4张图片

B(x)=x^2-x+2  { ( 0 , 2 ) , ( 1 , 2 ) , ( 2 , 4 ) }

     C(x)=A(x)+B(x)=2x^2+x+1   { ( 0, 1) , ( 1 , 4 ) , ( 2, 11 ) }

     注意乘法需要2n个点 lz比较懒就不写了……



于是我们得到一个计算多项式的方法:

FFT_第5张图片

 2.n次单位复数根

     FFT_第6张图片

有关复数根的性质可以百度到,不再赘述

http://baike.baidu.com/link?url=017EPfseoBwVxWpWPm5aunUn8x9dmRvioav9IubYLSKEGngK8_rDV2bd4PFCM8sJ

3.DFT&&FFT

使用单位根计算点值表达式叫DFT(离散傅里叶变换)复杂度n^2,FFT是其优化版复杂度nlogn

FFT_第7张图片FFT_第8张图片

FFT_第9张图片

计算FFT的伪代码(好吧用的是python的高亮)

下划线代表的是下标,括号代表上标,for 循环的range是左闭右开的

[python] view plain copy print ?
  1. FFT(a):  
  2.     n=a.length()  
  3.     if n==1:  
  4.         return a  
  5.     w_n=e^(pi*i/n)=complex(cos(2*pi/n),sin(2*pi/n))  
  6.     w=1  
  7.     a(0)=[a0,a2,...a_n-2]  
  8.     a(1)=[a1,a3,...a_n-1]  
  9.     y(0)=FFT(a(0))  
  10.     y(1)=FFT(a(1))  
  11.     for k in range(0,n/2):  
  12.         y_k=y_k(0)+w*y_k(1)  
  13.         y_k+n/2=y_k(0)-w*y_k(1)  
  14.         w=w*w_n  
  15.     return y  

4.递归=>迭代??


FFT的for循环中有两次w_n^k*y_k(1)的计算,于是可以改写成这样

 

[python] view plain copy print ?
  1. for k in range(0,n/2):  
  2.     t=w*y_k(1)  
  3.     y_k=y_k(0)+t  
  4.     y_k+n/2=y_k(0)-t  
  5.     w=w*w_n  
  6. #这一过程被称蝴蝶操作  

观察每次按照奇偶位置分割所形成的树:

FFT_第10张图片

每个数和他二进制相反的位置互换!!

伪代码(算导给的真是……)

[python] view plain copy print ?
  1. BIT-REVERSE-COPY(a,A):  
  2.     n=a.length()  
  3.     for k in range(0,n):  
  4.         A[rev(k)]=a_k  
  5. #算导说rev函数很好写,就没写……  

于是我们给出FFT的迭代实现的伪代码:

[python] view plain copy print ?
  1. FFT(a):  
  2.     BIT-REVERSE-COPY(a,A)  
  3.     n=a.length()  
  4.     for s in range(1,log2(n)+1):  
  5.         m=2^s  
  6.         w_m=e^(2*pi*i/m)=complex(cos(2*pi*m),sin(2*pi*m))  
  7.         for k in range(0,n,m):  
  8.             w=1  
  9.             for j in range(0,m/2):  
  10.                 t=w*A[k+j+m/2]  
  11.                 u=A[k+j]  
  12.                 A[k+j]=u+t  
  13.                 A[k+j+m/2]=u-t  
  14.                 w=w*w_m  
  15.     return A  
差不多讲完了,最后给出C++代码,有一大部分是lz借鉴别人的Code,以后附上地址

[cpp] view plain copy print ?
  1. #include<bitset>  
  2. #include <cstdio>  
  3. #include <cstring>  
  4. #include <cmath>  
  5. #include <algorithm>  
  6. #define N 400005  
  7. #define pi acos(-1.0) // PI值  
  8. using namespace std;  
  9. struct complex  
  10. {  
  11.     double r,i;  
  12.     complex(double real=0.0,double image=0.0){  
  13.         r=real; i=image;  
  14.     }  
  15.     // 以下为三种虚数运算的定义  
  16.     complex operator + (const complex o){  
  17.         return complex(r+o.r,i+o.i);  
  18.     }  
  19.     complex operator - (const complex o){  
  20.         return complex(r-o.r,i-o.i);  
  21.     }  
  22.     complex operator * (const complex o){  
  23.         return complex(r*o.r-i*o.i,r*o.i+i*o.r);  
  24.     }  
  25. }x1[N],x2[N];  
  26. char a[N/2],b[N/2];  
  27. int sum[N]; // 结果存在sum里  
  28. int vis[N];  
  29. void brc(complex *a,int l){//原来神犇的二进制平摊反转置换太神看不懂,蒟蒻写了一个O(n)的……   
  30.     memset(vis,0,sizeof(vis));//O(logn)的在后面   
  31.     for(int i=1;i<l-1;i++){  
  32.         int x=i,y=0;  
  33.         int m=(int)log2(l)+0.1;  
  34.         if(vis[x])continue;  
  35.         while(m--){  
  36.             y<<=1;  
  37.             y|=(x&1);  
  38.             x>>=1;  
  39.         }  
  40.         vis[i]=vis[y]=1;  
  41.         swap(a[i],a[y]);  
  42.     }     
  43. }  
  44. void fft(complex *y,int l,double on) // FFT O(nlogn)  
  45.                             // 其中on==1时为DFT,on==-1为IDFT  
  46. {  
  47.     register int h,i,j,k;  
  48.     complex u,t;   
  49.     brc(y,l); // 调用反转置换  
  50.     for(h=2;h<=l;h<<=1) // 控制层数  
  51.     {  
  52.         // 初始化单位复根  
  53.         complex wn(cos(on*2*pi/h),sin(on*2*pi/h));  
  54.         for(j=0;j<l;j+=h) // 控制起始下标  
  55.         {  
  56.             complex w(1,0); // 初始化螺旋因子  
  57.             for(k=j;k<j+h/2;k++) // 配对  
  58.             {  
  59.                 u=y[k];  
  60.                 t=w*y[k+h/2];  
  61.                 y[k]=u+t;  
  62.                 y[k+h/2]=u-t;  
  63.                 w=w*wn; // 更新螺旋因子  
  64.             } // 据说上面的操作叫蝴蝶操作…  
  65.         }  
  66.     }  
  67.     if(on==-1)  for(i=0;i<l;i++) y[i].r/=l; // IDFT  
  68. }  
  69. /*  
  70. void fft2(complex *a,int s,int t){//蒟蒻自己写的递归版FFT,不保证正确 ,代码内部有未定义变量  
  71.     if((n>>t)==1)return;//s记录起始,t记录深度,调用时应从0开始  
  72.     fft(a,s,t+1); 
  73.     fft(a,s+(1<<t),t+1); 
  74.     for(int i=0;i<(n>>(t+1));i++){ 
  75.         p=(i<<(t+1))+s; 
  76.         wt=w[i<<t]*a[p+(1<<t)]; 
  77.         tt[i]=a[p]+wt; 
  78.         tt[i+(n>>(t+1))]=a[p]-wt; 
  79.     } 
  80.     for(i=0;i<(n>>t);i++)a[(i<<t)+s]=tt[i]; 
  81. }*/  
  82. int main(void)  
  83. {  
  84.     int l1,l2,l;  
  85.     register int i;  
  86.     while(scanf("%s%s",a,b)!=EOF)  
  87.     {  
  88.         l1=strlen(a);  
  89.         l2=strlen(b);  
  90.         l=1;  
  91.         while(l<l1*2 || l<l2*2)   l<<=1; // 将次数界变成2^n  
  92.                                         // 配合二分与反转置换  
  93.         for(i=0;i<l1;i++) // 倒置存入  
  94.         {  
  95.             x1[i].r=a[l1-i-1]-'0';  
  96.             x1[i].i=0.0;  
  97.         }  
  98.         for(;i<l;i++)    x1[i].r=x1[i].i=0.0;  
  99.         // 将多余次数界初始化为0  
  100.         for(i=0;i<l2;i++)  
  101.         {  
  102.             x2[i].r=b[l2-i-1]-'0';  
  103.             x2[i].i=0.0;  
  104.         }  
  105.         for(;i<l;i++)    x2[i].r=x2[i].i=0.0;  
  106.         fft(x1,l,1); // DFT(a)  
  107.         fft(x2,l,1); // DFT(b)  
  108.         for(i=0;i<l;i++) x1[i]=x1[i]*x2[i]; // 点乘结果存入a  
  109.         fft(x1,l,-1); // IDFT(a*b)  
  110.         for(i=0;i<l;i++) sum[i]=x1[i].r+0.5; // 四舍五入  
  111.         for(i=0;i<l;i++) // 进位  
  112.         {  
  113.             sum[i+1]+=sum[i]/10;  
  114.             sum[i]%=10;  
  115.         }  
  116.         l=l1+l2-1;  
  117.         while(sum[l]<=0 && l>0)   l--; // 检索最高位  
  118.         for(i=l;i>=0;i--)    putchar(sum[i]+'0'); // 倒序输出  
  119.         putchar('\n');  
  120.     }  
  121.     return 0;  
  122. }  
  123. /*void brc(complex *y,int l) // 二进制平摊反转置换 O(logn) 
  124. { 
  125.     register int i,j,k; 
  126.     for(i=1,j=l/2;i<l-1;i++) 
  127.     { 
  128.         if(i<j)  swap(y[i],y[j]); // 交换互为下标反转的元素 
  129.                                 // i<j保证只交换一次 
  130.         k=l/2; 
  131.         while(j>=k) // 由最高位检索,遇1变0,遇0变1,跳出 
  132.         { 
  133.             j-=k; 
  134.             k>>=1; 
  135.         } 
  136.         if(j<k)  j+=k; 
  137.     } 
  138. }*/  

pyc神犇的写法,bzoj3527,zjoi2014 力,无限YM

[cpp] view plain copy print ?
  1. #include<cmath>  
  2. #include<cstdio>  
  3. #include<cstring>  
  4. #include<iostream>  
  5. #include<algorithm>  
  6. using namespace std;  
  7. const int maxn=1000010;  
  8. int n,N,L;  
  9. int rev[maxn];  
  10. int dig[maxn];  
  11. double p[maxn];  
  12. struct cp{  
  13.     double r,i;  
  14.     cp(double _r=0,double _i=0):  
  15.         r(_r),i(_i){}  
  16.     cp operator+(cp x){return cp(r+x.r,i+x.i);}  
  17.     cp operator-(cp x){return cp(r-x.r,i-x.i);}  
  18.     cp operator*(cp x){return cp(r*x.r-i*x.i,r*x.i+i*x.r);}  
  19. };  
  20. cp a[maxn],b[maxn],c[maxn],A[maxn],x,y;  
  21. void FFT(cp a[],int flag){  
  22.     for(int i=0;i<N;i++)A[i]=a[rev[i]];  
  23.     for(int i=0;i<N;i++)a[i]=A[i];  
  24.     for(int i=2;i<=N;i<<=1){  
  25.         cp wn(cos(2*M_PI/i),flag*sin(2*M_PI/i));  
  26.         for(int k=0;k<N;k+=i){  
  27.             cp w(1,0);  
  28.             for(int j=0;j<i/2;j++){  
  29.                 x=a[k+j];  
  30.                 y=w*a[k+j+i/2];  
  31.                 a[k+j]=x+y;  
  32.                 a[k+j+i/2]=x-y;  
  33.                 w=w*wn;    
  34.             }  
  35.         }  
  36.     }  
  37.     if(flag==-1)for(int i=0;i<N;i++)a[i].r/=N;  
  38. }  
  39. double anss[maxn];  
  40. int main(){  
  41.     scanf("%d",&n);  
  42.     for(int i=0;i<n;i++)scanf("%lf",&p[i]);  
  43.     for(L=0,N=1;N<n;N<<=1,L++);L++;N<<=1;  
  44.     for(int i=0;i<N;i++){  
  45.         int len=0;  
  46.         for(int t=i;t;t>>=1)dig[len++]=t&1;  
  47.         for(int j=0;j<L;j++)rev[i]=rev[i]*2+dig[j];  
  48.     }  
  49.     for(int i=0;i<n;i++)a[i]=cp(p[i],0);  
  50.     for(int i=1;i<n;i++)b[i]=cp(1.0/i/i,0);  
  51.     FFT(a,1);FFT(b,1);  
  52.     for(int i=0;i<N;i++)c[i]=a[i]*b[i];  
  53.     FFT(c,-1);  
  54.     for(int i=0;i<n;i++)anss[i]=c[i].r;  
  55.     memset(a,0,sizeof(a));  
  56.     memset(b,0,sizeof(b));  
  57.     for(int i=0;i<n;i++)a[i]=cp(p[n-i-1],0);  
  58.     for(int i=1;i<n;i++)b[i]=cp(1.0/i/i,0);  
  59.     FFT(a,1);FFT(b,1);  
  60.     for(int i=0;i<N;i++)c[i]=a[i]*b[i];  
  61.     FFT(c,-1);  
  62.     for(int i=0;i<n;i++)anss[i]-=c[n-i-1].r;  
  63.     for(int i=0;i<n;i++)  
  64.         printf("%.9f\n",anss[i]);  
  65.     return 0;  
  66. }  


重新过了一遍高精乘

[cpp] view plain copy print ?
  1. #include<cstdio>  
  2. #include<cmath>  
  3. #include<cstring>  
  4. #include<iostream>  
  5. #include<algorithm>  
  6. using namespace std;  
  7. const int maxn=1e6+10;  
  8. struct cp{  
  9.     double r,i;  
  10.     cp(double _r=0,double _i=0):  
  11.         r(_r),i(_i){}  
  12.     cp operator+(cp x){return cp(r+x.r,i+x.i);}  
  13.     cp operator-(cp x){return cp(r-x.r,i-x.i);}  
  14.     cp operator*(cp x){return cp(r*x.r-i*x.i,r*x.i+i*x.r);}  
  15. };  
  16. cp a[maxn],b[maxn],A[maxn],x,y,c[maxn];  
  17. char s1[maxn],s2[maxn];  
  18. int sum[maxn],a1[maxn],a2[maxn],dig[maxn];  
  19. int len1,len2,rev[maxn],N,L;  
  20. void FFT(cp a[],int flag){  
  21.     for(int i=0;i<N;i++)A[i]=a[rev[i]];  
  22.     for(int i=0;i<N;i++)a[i]=A[i];  
  23.     for(int i=2;i<=N;i<<=1){  
  24.         cp wn(cos(2*M_PI/i),flag*sin(2*M_PI/i));  
  25.         for(int k=0;k<N;k+=i){  
  26.             cp w(1,0);  
  27.             for(int j=k;j<k+i/2;j++){  
  28.                 x=a[j];  
  29.                 y=a[j+i/2]*w;  
  30.                 a[j]=x+y;  
  31.                 a[j+i/2]=x-y;  
  32.                 w=w*wn;  
  33.             }  
  34.         }  
  35.     }  
  36.     if(flag==-1)for(int i=0;i<N;i++)a[i].r/=N;  
  37. }  
  38. int main(){  
  39.     scanf("%s%s",s1,s2);  
  40.     len1=strlen(s1);  
  41.     len2=strlen(s2);  
  42.     for(N=1,L=0;N<max(len1,len2);N<<=1,L++);N<<=1;L++;  
  43.     for(int i=0;i<N;i++){  
  44.         int len=0;  
  45.         for(int t=i;t;t>>=1)dig[len++]=t&1;  
  46.         for(int j=0;j<L;j++)rev[i]=(rev[i]<<1)|dig[j];  
  47.     }  
  48.     for(int i=0;i<len1;i++)a1[len1-i-1]=s1[i]-'0';  
  49.     for(int i=0;i<len2;i++)a2[len2-i-1]=s2[i]-'0';  
  50.     for(int i=0;i<N;i++)a[i]=cp(a1[i]);  
  51.     for(int i=0;i<N;i++)b[i]=cp(a2[i]);  
  52.     FFT(a,1);FFT(b,1);  
  53.     for(int i=0;i<N;i++)c[i]=a[i]*b[i];  
  54.     FFT(c,-1);  
  55.     for(int i=0;i<N;i++)sum[i]=c[i].r+0.5;  
  56.     for(int i=0;i<N;i++){  
  57.         sum[i+1]+=sum[i]/10;  
  58.         sum[i]%=10;  
  59.     }  
  60.     int l=len1+len2-1;  
  61.     while(sum[l]==0&&l>0)l--;  
  62.     for(int i=l;i>=0;i--)  
  63.     putchar(sum[i]+'0');  
  64.     putchar('\n');  
  65.     return 0;  
  66. }  

你可能感兴趣的:(FFT)