浅谈算法——从多项式乘法到FFT

多项式乘法

我们知道,多项式可以表示成:

A=i=0naixi A = ∑ i = 0 n a i x i
的形式。
对于两个多项式 A(x) A ( x ) B(x) B ( x ) ,我们可以计算乘积 AB A ⋅ B
AB=i=0sizeAj=0sizeBaibjxi+j A ⋅ B = ∑ i = 0 s i z e A ∑ j = 0 s i z e B a i b j x i + j

但是,这样算是 O(sizeAsizeB) O ( s i z e A ⋅ s i z e B ) 的,太慢了,怎么办?
我们需要换一条思路。

首先,我们得知道一个东西:多项式的点值表示法
我们把上面的称为多项式的系数表示法,而点值表示法就是:
A A 多项式的次数为 n n ,则任取 n n 个不相同的 x0,x1,,xn x 0 , x 1 , ⋯ , x n ,求出 A A 多项式的 A(x0),A(x1),,A(xn) A ( x 0 ) , A ( x 1 ) , ⋯ , A ( x n ) 。记为:

<(x0,A(x0)),(x1,A(x1)),,(xn,A(xn))> < ( x 0 , A ( x 0 ) ) , ( x 1 , A ( x 1 ) ) , ⋯ , ( x n , A ( x n ) ) >
显然,一个有 n+1 n + 1 个点的点对唯一表示一个 n n 次多项式。

对于一个点值表示法下多项式

<(x0,A(x0)),(x1,A(x1)),,(xn,A(xn))> < ( x 0 , A ( x 0 ) ) , ( x 1 , A ( x 1 ) ) , ⋯ , ( x n , A ( x n ) ) >
<(x0,B(x0)),(x1,B(x1)),,(xn,B(xn))> < ( x 0 , B ( x 0 ) ) , ( x 1 , B ( x 1 ) ) , ⋯ , ( x n , B ( x n ) ) >
它们的乘积是
<(x0,A(x0)B(x0)),(x1,A(x1)B(x1)),,(xn,A(xn)B(xn))> < ( x 0 , A ( x 0 ) ⋅ B ( x 0 ) ) , ( x 1 , A ( x 1 ) ⋅ B ( x 1 ) ) , ⋯ , ( x n , A ( x n ) ⋅ B ( x n ) ) >
可以看出点值表示法的多项式相乘是 O(n) O ( n ) 的。

等等,我们好像找到了一个突破口!
为啥不把原来的系数表示法下的多项式转化成点值表示法呢?
仔细想一想:系数表示法与点值表示法互相转换,这个步骤好像是 O(n2) O ( n 2 ) 的。
FFT(快速傅里叶变换)就是为了优化这个 O(n2) O ( n 2 )

PS:对于 O(n2) O ( n 2 ) 的点值表示法转化成系数表示法可以看百度百科中关于插值法的介绍。

FFT(快速傅里叶变换)

如果未特别说明,那么下面的多项式次数将是 2k1 2 k − 1 的形式。
如果不是关键部分的公式或定理,不提供证明,自己出门右转百度。

首先介绍两个概念:
DFT(离散傅里叶变换)是将多项式由系数表示法转化成点值表示法;
IDFT(离散傅里叶逆变换)是将多项式由点值表示法转化成系数表示法;
而FFT就是上述两种变换的优化。

DFT部分

前置技能:
下面的内容将会提到复数,不会的可以参考百度百科中关于复数的介绍;

为了介绍FFT中的DFT部分,首先要介绍的是一个概念:单位根
单位根:若有

zn=1 z n = 1
此时将 z z 称为 n n 次单位根。
若有 zR z ∈ R ,显然, z z 可以等于 1 1 ,如果 n n 是偶数,则 z z 还可以等于 1 − 1
我们把范围扩大到 zC z ∈ C ,那么,我们可以得到 n n 个复数,它们将复平面上的单位圆等分成 n n 份。

为了表示 n n 次单位根,我们引入一个公式。
欧拉公式:

ein=cosn+isinn e i n = cos ⁡ n + i sin ⁡ n

如果我们令:

ωn=e2πi/n ω n = e 2 π i / n
那么, n n 次单位根就可以表示成 ω0n,ω1n,,ωn1n ω n 0 , ω n 1 , ⋯ , ω n n − 1 ,它们的 n n 次方显然都是 1 1

下面是关于 ωn ω n 的两条性质:(都是在 n n 为偶数的情况下)

ωn/2n=e(2πi/n)(n/2)=eπi=cosπ+isinπ=1(1) (1) ω n n / 2 = e ( 2 π i / n ) ⋅ ( n / 2 ) = e π i = c o s π + i sin ⁡ π = − 1

ω2n=e22πi/n=ωn/2(2) (2) ω n 2 = e 2 ⋅ 2 π i / n = ω n / 2

下面,我们进入正题:DFT的求法
在这里,我们令多项式次数为 n1 n − 1 ,那么我们可以用点值表示成

<(ω0n,A(ω0n),(ω1n,A(ω1n)),,(ωn1n,A(ωn1n))> < ( ω n 0 , A ( ω n 0 ) , ( ω n 1 , A ( ω n 1 ) ) , ⋯ , ( ω n n − 1 , A ( ω n n − 1 ) ) >

额……这时间复杂度好像并没有减少……
别急,我们来看 A(ωkn) A ( ω n k ) 能够表示成什么。

A(ωkn)====i=0n1aiωkini=0n/21a2iω2kin+i=0n/21a2i+1ω2ki+kni=0n/21a2iωkin/2+i=0n/21a2i+1ωkin/2ωkni=0n/21a2iωkin/2+ωni=0n/21a2i+1ωkin/2(3)(4)(5)(6) (3) A ( ω n k ) = ∑ i = 0 n − 1 a i ω n k i (4) = ∑ i = 0 n / 2 − 1 a 2 i ω n 2 k i + ∑ i = 0 n / 2 − 1 a 2 i + 1 ω n 2 k i + k (5) = ∑ i = 0 n / 2 − 1 a 2 i ω n / 2 k i + ∑ i = 0 n / 2 − 1 a 2 i + 1 ω n / 2 k i ω n k (6) = ∑ i = 0 n / 2 − 1 a 2 i ω n / 2 k i + ω n ∑ i = 0 n / 2 − 1 a 2 i + 1 ω n / 2 k i

我们来分别看一看这神奇的步骤。
(3) ( 3 ) 步骤就是将 ωkn ω n k 带入原来的 A A 多项式。
(4) ( 4 ) 步骤就是将原多项式拆成两个部分,按奇偶分类。
(5) ( 5 ) 步骤用到了上面提到的性质 (2) ( 2 )
(6) ( 6 ) 步骤就是上面式子的后半部分提出公因数。

有了这个等式,我们就可以分治+递归解决DFT了。
算法步骤:

  • 对当前的多项式(一个数组)系数进行奇偶分类;
  • 递归算出偶数部分的数组的 anse a n s e 和奇数部分的数组的 anso a n s o
  • 这个多项式的 ans=anse+ωnanso a n s = a n s e + ω n a n s o

但是这个的常数好像很大啊?能不能减少一点呢?
上面的性质 (1) ( 1 ) 给了我们提示:

ωn/2+kn=ωn/2nωkn=ωkn ω n n / 2 + k = ω n n / 2 ⋅ ω n k = − ω n k

在算 k<n2 k < n 2 时,可以顺便把 kn2 k ≥ n 2 的情况也算出来。

常数减小了一半!但是还是很大啊!
递归版的程序一般比非递归版慢,为啥不用非递归版呢?

算法核心就是奇偶分类,分来分去最后分到了哪里?我们来研究研究。
显然,一个序列原来是 0,1,2,3,4,5,6,7 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ,最终变成 0,4,2,6,1,5,3,7 0 , 4 , 2 , 6 , 1 , 5 , 3 , 7
把它们的二进制列出来:

000,001,010,011,100,101,110,111000,100,010,110,001,101,011,111 000 , 001 , 010 , 011 , 100 , 101 , 110 , 111 000 , 100 , 010 , 110 , 001 , 101 , 011 , 111

其中,上面是位置,下面是这个位置对应的数。

把上面的数翻转,好像就是下面的数!
没错,只需要计算一下每个数的二进制翻转后的结果,就能得到一个数最终对应的位置,也就能实现非递归版了。

代码:

int dft_fast(complex* ar,int len)
{
  for(register int i=0; iif(rev[i]std::swap(ar[rev[i]],ar[i]);//交换一个位置和它的翻转后位置
        }
    }
  for(register int i=2; i<=len; i<<=1)//i代表当前序列的长度
    {
      complex wn(cos(2*pi/i),sin(2*pi/i));//omega_n
      for(register int j=0; j//j代表序列的起始位置
        {
          complex w(1,0);//下面代表omega_n^k
          for(register int k=0; k<(i>>1); ++k)//枚举i次单位根的每一种取值
            {
              complex x=ar[j+k],y=w*ar[j+k+(i>>1)];
              ar[j+k]=x+y;//合并操作,将两边合并成一个点值表示法下的多项式
              ar[j+k+(i>>1)]=x-y;
              w=w*wn;
            }
        }
    }
  return 0;
}

IDFT部分

回顾上面的DFT部分,仔细思考一下,它本质就是在求:

a0(ω0n)0+a1(ω0n)1++an1(ω0n)n1=A(ω0n)a0(ω1n)0+a1(ω1n)1++an1(ω1n)n1=A(ω1n)a0(ωn1n)0+a1(ωn1n)1++an1(ωn1n)n1=A(ωn1n) { a 0 ( ω n 0 ) 0 + a 1 ( ω n 0 ) 1 + ⋯ + a n − 1 ( ω n 0 ) n − 1 = A ( ω n 0 ) a 0 ( ω n 1 ) 0 + a 1 ( ω n 1 ) 1 + ⋯ + a n − 1 ( ω n 1 ) n − 1 = A ( ω n 1 ) ⋯ a 0 ( ω n n − 1 ) 0 + a 1 ( ω n n − 1 ) 1 + ⋯ + a n − 1 ( ω n n − 1 ) n − 1 = A ( ω n n − 1 )

其中,给定了 a0,a1,,an1 a 0 , a 1 , ⋯ , a n − 1 以及 ω0n,ω1n,,ωn1n ω n 0 , ω n 1 , ⋯ , ω n n − 1 ,求 A(ω0n),A(ω1n),,A(ωn1n) A ( ω n 0 ) , A ( ω n 1 ) , ⋯ , A ( ω n n − 1 ) 的值。

用矩阵表示如下:

(ω0n)0(ω1n)0(ωn1n)0(ω0n)1(ω1n)1(ωn1n)1(ω0n)n1(ω1n)n1(ωn1n)n1a0a1an1=A(ω0n)A(ω1n)A(ωn1n)(7) (7) [ ( ω n 0 ) 0 ( ω n 0 ) 1 ⋯ ( ω n 0 ) n − 1 ( ω n 1 ) 0 ( ω n 1 ) 1 ⋯ ( ω n 1 ) n − 1 ⋮ ⋮ ⋱ ⋮ ( ω n n − 1 ) 0 ( ω n n − 1 ) 1 ⋯ ( ω n n − 1 ) n − 1 ] [ a 0 a 1 ⋮ a n − 1 ] = [ A ( ω n 0 ) A ( ω n 1 ) ⋮ A ( ω n n − 1 ) ]

我们令:

V=(ω0n)0(ω1n)0(ωn1n)0(ω0n)1(ω1n)1(ωn1n)1(ω0n)n1(ω1n)n1(ωn1n)n1 V = [ ( ω n 0 ) 0 ( ω n 0 ) 1 ⋯ ( ω n 0 ) n − 1 ( ω n 1 ) 0 ( ω n 1 ) 1 ⋯ ( ω n 1 ) n − 1 ⋮ ⋮ ⋱ ⋮ ( ω n n − 1 ) 0 ( ω n n − 1 ) 1 ⋯ ( ω n n − 1 ) n − 1 ]

那么IDFT的本质就是求 V V 矩阵的逆矩阵。

考虑下面这个矩阵:

D=(ω0n)0(ω1n)0(ω(n1)n)0(ω0n)1(ω1n)1(ω(n1)n)1(ω0n)n1(ω1n)n1(ω(n1)n)n1 D = [ ( ω n − 0 ) 0 ( ω n − 0 ) 1 ⋯ ( ω n − 0 ) n − 1 ( ω n − 1 ) 0 ( ω n − 1 ) 1 ⋯ ( ω n − 1 ) n − 1 ⋮ ⋮ ⋱ ⋮ ( ω n − ( n − 1 ) ) 0 ( ω n − ( n − 1 ) ) 1 ⋯ ( ω n − ( n − 1 ) ) n − 1 ]

那么我们令 E=DV E = D ⋅ V ,则:

Ei,j===k=0n1Di,kVk,jk=0n1(ωin)k(ωkn)jk=0n1ωk(ji)n E i , j = ∑ k = 0 n − 1 D i , k V k , j = ∑ k = 0 n − 1 ( ω n − i ) k ( ω n k ) j = ∑ k = 0 n − 1 ω n k ( j − i )

显然:

Ei,j={0(ij)n(i=j) E i , j = { 0 ( i ≠ j ) n ( i = j )

因此,

1nDV=1nE=In 1 n D ⋅ V = 1 n E = I n

(7) ( 7 ) 式两边同时左乘 1nD 1 n D ,可得
a0a1an1=1n(ω0n)0(ω1n)0(ω(n1)n)0(ω0n)1(ω1n)1(ω(n1)n)1(ω0n)n1(ω1n)n1(ω(n1)n)n1A(ω0n)A(ω1n)A(ωn1n) [ a 0 a 1 ⋮ a n − 1 ] = 1 n [ ( ω n − 0 ) 0 ( ω n − 0 ) 1 ⋯ ( ω n − 0 ) n − 1 ( ω n − 1 ) 0 ( ω n − 1 ) 1 ⋯ ( ω n − 1 ) n − 1 ⋮ ⋮ ⋱ ⋮ ( ω n − ( n − 1 ) ) 0 ( ω n − ( n − 1 ) ) 1 ⋯ ( ω n − ( n − 1 ) ) n − 1 ] [ A ( ω n 0 ) A ( ω n 1 ) ⋮ A ( ω n n − 1 ) ]

这就相当于把DFT中 ωkn ω n k 都换成 ωkn ω n − k

FFT总代码

int fft(complex* ar,int len,int op)
{
  for(register int i=0; iif(rev[i]std::swap(ar[rev[i]],ar[i]);
        }
    }
  for(register int i=2; i<=len; i<<=1)
    {
      complex wn(cos(2*pi/i),sin(2*pi*op/i));//只有这里较DFT代码有变动
      for(register int j=0; jcomplex w(1,0);
          for(register int k=0; k<(i>>1); ++k)
            {
              complex x=ar[j+k],y=w*ar[j+k+(i>>1)];
              ar[j+k]=x+y;
              ar[j+k+(i>>1)]=x-y;
              w=w*wn;
            }
        }
    }
  if(op==-1)
    {
      for(register int i=0; ireturn 0;
}

多项式乘法模板

#include 
#include 
#include 

const int maxn=100000;
const double pi=acos(-1);

struct complex
{
  double r,i;

  complex(double r_=0,double i_=0)
  {
    r=r_;
    i=i_;
  }

  complex operator +(const complex &other)
  {
    return complex(r+other.r,i+other.i);
  }

  complex operator -(const complex &other)
  {
    return complex(r-other.r,i-other.i);
  }

  complex operator *(const complex &other)
  {
    return complex(r*other.r-i*other.i,r*other.i+i*other.r);
  }
};

complex a[maxn<<2],b[maxn<<2],c[maxn<<2];
int rev[maxn<<2],n,m;

int fft(complex* ar,int len,int op)
{
  for(register int i=0; iif(rev[i]std::swap(ar[rev[i]],ar[i]);
        }
    }
  for(register int i=2; i<=len; i<<=1)
    {
      complex wn(cos(2*pi/i),sin(2*pi*op/i));
      for(register int j=0; jcomplex w(1,0);
          for(register int k=0; k<(i>>1); ++k)
            {
              complex x=ar[j+k],y=w*ar[j+k+(i>>1)];
              ar[j+k]=x+y;
              ar[j+k+(i>>1)]=x-y;
              w=w*wn;
            }
        }
    }
  if(op==-1)
    {
      for(register int i=0; ireturn 0;
}

int main()
{
  scanf("%d%d",&n,&m);
  for(register int i=0; i<=n; ++i)
    {
      scanf("%lf",&a[i].r);
    }
  for(register int i=0; i<=m; ++i)
    {
      scanf("%lf",&b[i].r);
    }
  n=n+m;
  int l=0;
  m=1;
  while(m<=n)
    {
      ++l;
      m<<=1;
    }
  for(register int i=0; i>1]>>1)|((i&1)<<(l-1));
    }
  fft(a,m,1);
  fft(b,m,1);
  for(register int i=0; i1);
  for(register int i=0; iprintf("%d ",(int)(c[i].r+0.5));
    }
  printf("%d\n",(int)(c[n].r+0.5));
  return 0;
}

你可能感兴趣的:(FFT)