DFT和FFT详解(算法导论学习笔记)

代码均为做严格测试,仅供参考

分治法基本原理

将原问题分解为几个规模较小但类似于原问题的子问题,递归的求解这些子问题。然后再合并这些子问题的解来建立原问题的解。递归求解这些子问题,然后再合并这些子问题的解来建立原问题的解。

分治法在分层递归时都有三个步骤:

  • 分解原问题为若干子问题,这些子问题是原问题规模较小的实例。
  • 解决这些子问题,递归的求解各个子问题。然而若子问题的规模足够小。则直接求解。
  • 合并这些子问题的解成原问题的解。

问题描述

两个N次多项式相乘,最直接的复杂度为 O(n2) ,运用傅里叶变换,则可以吧多项式相乘的复杂度转化为 nlog(n)

输入输出均采用系数表达,假设n是2的幂,否则通过添加系数为0的高阶系数。算法准备过程如下:

  • 加倍次数界:通过添加n个系数为0的高阶系数,把多项式A(x)和B(x)变为次数界为2n的多项式,并构造其系数表达。
  • 求值:通过应用2n阶的FFT计算出A(x)和B(x)的长度为2n的点值表达。这些点值表达式中包含了两个多项式在2n次单位根处的取值。
  • 逐点相乘:把A(x)和B(x)的值逐点相乘,可以计算出多项式C(x)=A(x)B(x)长度为2n的点值表达,这个表示中包含了C(x)在每个2n单位根处的值。
  • 插值:通过对2n个点值对应用fft,计算其逆DFT,就可以构造出多项式C(x)的系数表达。

基本概念和定理

  1. 单位复数根

    n次单位复数根是满足 ωn=1 的复数 ω ,这些根正好有n个,分别是 e2πikn(k=0,1,2...n1)

    其中 e2πin 被称作主n次单位根。 ωjωk=ω(k+j)modn

  2. 消去引理

    对于任意整数n>=0和k>=0,以及d>=0

    ωdkdn=ωkn

  3. 折半引理

    如果n>0为偶数,那么n个n次单位复数根的平方的集合就是n/2个n/2次单位复数根的集合。

    对任意非负整数k,我们有 (ωkn)2=ωkn/2

  4. 求和引理

    对任意整数n>=1和不能被n整除的非负整数k,有

    n1j=0(ωkn)j=0

算法实现

DFT

我们希望计算次数界为n的多项式A(x)= n1j=0ajxj

在n个n次单位复数根处的值,假设A以系数形式给出: a=(a0,a1,a2...an1) 。接下来对k=0,1,2,..n-1,定义结果 yk :

yk=A(ωkn)=n1j=0ajωkjn

向量 y=(y0,y1,...yn1) 就是系数向量a=( a0,a1,...,an1 )的离散傅里叶变换DFT

FFT

通过使用快速傅里叶变换的方法,利用复数单位根的特殊性质,我们就可以在 θ(nlgn) 的时间内计算出DFT(a)。

首先分别定义两个新的次数界为n/2的多项式

A[0](x)=a0+a2+...an2xn/21

A[1](x)=a1+a3+...an1xn/21

分别包含了所有偶数下标的系数和奇数下标的系数。

A(x)=A[0](X2)+xA[1](x2)

因而求A(x)在 ω0n,ω1n,,ωn1n 处的值得问题转化为:

  1. 求次数界为n/2的多项式 A[0](x)+xA[1](x) 在点 (ω0)2(ωn1)2 处的取值

  2. 用递归方法计算fft的伪代码如下

     RECURSIVE_FFT(a[])
      {
        n=a.lenth
        if(n==1) return a
        wn=e^(2*pi*i/n)
        w=1
        a0[]=(a0,a2,a4...)
        a1[]=(a1,a3,a5...)
        y0[]=RECURSIVE_FFT(a0[])
        y1[]=RECURSIVE_FFT(a1[])
        for k=0 to n/2
          y[k]=y0[k]+w*y1[k]
          y[k+n/2]=y0[k]-w*y1[k]
          w=w*wn
        return y
      }
  3. 计算出逆DFT。将fft算法进行修改,将a与y互换,用 ω1nωn ,并将计算结果的每个数除以n。

算法复杂度分析以及可能的优化

T(n)=2T(n/2)+θ(n)=θ(nlgn)

从算法实现上来看,整体的时间复杂度无法进行优化。

但是可以把递归的算法改成迭代的形式实现。从而节省栈中的空间。同时迭代算法可以做到常数上的优化。

优化实现主要代码

int rev(int k,int n){
    int res=0;
    while(n){
        int x=k&1;
        res=res*2+x;
        k>>=1;
        n>>=1;
    }
    return res;
}

void bit_reverse(vector a,vector A){
    int n=(int)a.size();
    A.resize(n);
    for(int k=0;k1)]=a[k];
    }
}

vector iterative_fft(vector a,double op){
    vector A;
    bit_reverse(a, A);
    int n=(int)a.size();
    for(int s=0;(1<int m=1<int temp=2*m;
        Complex wm=Complex(cos(pi/m*op), sin(pi/m*op));
        for(int k=0;k1, 0);
            for(int j=0;jreturn A;
}

整体实现代码

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;

#define _ sync_with_stdio(false)
typedef long long ll;
typedef complex<double> Complex;
const double pi=acos(-1);
const int INF=0x7fffffff;

vector recursive_fft(vector a,double op){
    vector y;
    int n=(int)a.size();
    if(n==1)
        return a;
    Complex w=Complex(1,0);
    Complex wn=Complex(cos(2*pi/(n*op)),sin(2*pi/(n*op)));
    vector a0,a1;
    for(int i=0;iif(i&1){
            a1.push_back(a[i]);
        }else{
            a0.push_back(a[i]);
        }
    }
    vector y0=recursive_fft(a0, op);
    vector y1=recursive_fft(a1, op);
    y.resize(n);
    for(int k=0;k<=n/2-1;k++){
        y[k]=y0[k]+w*y1[k];
        y[k+n/2]=y0[k]-w*y1[k];
        w=w*wn;
    }
    return y;
}


int rev(int k,int n){
    int res=0;
    while(n){
        int x=k&1;
        res=res*2+x;
        k>>=1;
        n>>=1;
    }
    return res;
}

void bit_reverse(vector a,vector A){
    int n=(int)a.size();
    A.resize(n);
    for(int k=0;k1)]=a[k];
    }
}

vector iterative_fft(vector a,double op){
    vector A;
    bit_reverse(a, A);
    int n=(int)a.size();
    for(int s=0;(1<int m=1<int temp=2*m;
        Complex wm=Complex(cos(pi/m*op), sin(pi/m*op));
        for(int k=0;k1, 0);
            for(int j=0;jreturn A;
}


int main() {
    int n,m;
    cout<<"请输入第一个多项式的长度:";
    cin>>n;
    cout<<"请依次输入第一个多项式的系数:";
    vector s0,s1;
    for(int i=0;idouble x;
        cin>>x;
        s0.push_back(Complex(x,0));
    }
    cout<<"请输入第二个多项式的长度:";
    cin>>m;
    cout<<"请依次输入第二个多项式的系数:";
    for(int i=0;idouble x;
        cin>>x;
        s1.push_back(Complex(x,0));
    }
    //在容器后面补0使得其的阶变为2的幂次
    int MAX=max(n,m);
    int temp=1;
    while(temp1;
    }
    MAX=temp*2;
    //cout<
    for(int i=n;i0,0));
    for(int i=m;i0,0));
    //计算DFT
    vector  ans0=recursive_fft(s0, 1);
    vector ans1=recursive_fft(s1, 1);
    vector ans;
    ans.resize(MAX);
    for(int i=0;ivector res=recursive_fft(ans, -1);
    cout<<"相乘之后结果序列表示为:";
    for(int i=0;i1;i++){
        //cout<<"res[i]="<
        if(i==n+m-2)
            cout<else
            cout<" ";
    }
}

你可能感兴趣的:(算法,FFT,算法导论,fft,算法,分治算法)