多项式
对于多项式$ f\left(x\right)=\sum_{i=0}^{|f|}{f_ix^i} $,其中|f|表示多项式的阶数,fi表示多项式f中x^i的系数。
多项式的加法定义为$ c\left(x\right)=a\left(x\right)+b\left(x\right)=\sum_{i=0}^{\max\left(|a|,|b|\right)}{\left(a_i+b_i\right)x^i} $,即$ c_k=a_k+b_k $。
多项式的乘法定义为$ c\left(x\right)=a\left(x\right)\cdot b\left(x\right)=\sum_{k=0}^{|a|+|b|}{\left(\sum_{i+j=k}^{}{a_ib_j}\right)x^k} $,即$ c_k=\sum_{i+j=k}^{}{a_ib_j} $。
显然要计算两个多项式a(x),b(x)的乘积,程序的时间复杂度为O(|a||b|)。
naive(a, b) c = 0 for(i = 0; i <= |a|; i++) for(j = 0; j <= |b|; j++) c[i + j] = c[i + j] + a[i] * b[j]
在多项式阶数超过10w的时候,这个方法就完全顶不住了。不过幸好还有很多加快多项式乘运算速度的算法,而快速傅立叶变换就是其中之一。
先了解一下多项式的其它操作的时间复杂度:多项式的乘法虽然很慢,但是求解一个多项式f在x=x0的时候的取值f(x0)是可以在O(|f|)时间复杂度内做到的。以及多项式的加法a(x)+b(x)也可以在O(max(|a|,|b|))的时间复杂度内做到。
再了解一下多项式的表示方法:多项式的表示方法基本有两种,一种是通过系数序列(f0,f1,...,fk)来表示一个k阶多项式f,这种方法称为系数表示法,还有一种就是点值表示法,即用k+1个不同的点来表示一个k阶多项式。系数表示法大家都很了解,下面说一下点值表示法。
在代数中,说明过n个不同的点可以唯一确定一个k阶多项式,其中k 点值表示法非常适合用于计算多项式乘法,对于多项式乘法c(x)=a(x)*b(x),假设我们已经确认了|a|+|b|+1个不同的x值x0,x1,...,且分别计算出了a(x0),b(x0),a(x1),b(x1),...,那么我们就得知点(xi,a(xi)b(xi))是多项式c上的点,而这组点的数目为|a|+|b|+1>|c|,故c被这组点唯一确认,在前提下多项式乘法可以以O(|c|)的时间复杂度运行。 当然点值表达式的前提并不好满足,我们往往需要先通过插值取回多项式,之后再计算在额外的点的多项式值,之后再利用点值乘法算出新的多项式的点值表达式。这整个过程的时间复杂度为O(|c|^2)。 我们设n为大于等于c长度的最小2的幂次(即n=2^k>=|c|>2^(k-1)),在运算多项式乘法前,我们先将a与b通过前面补0将长度扩充到n,之后再运行多项式乘法。下面我们说明如何在O(nlog2n)的时间复杂度内将长度为n的多项式从系数表达式转换为点值表达式,并在O(nlog2n)的时间复杂度内将n个不同的点插值会系数表达式的多项式,而这一算法就是快速傅立叶变换,很显然这一过程的时间复杂度为O(nlog2n+n+nlog2n)=O(nlog2n)。 首先我们要谨慎地选取n个点的x值。在复数域中,n次单位复数根是满足w^n=1的所有复数w。由欧拉公式$ e^{iu}=\cos\left(u\right)+i\sin\left(u\right) $可知n次复数根分别为复数$ e^{\left(2\pi k/n\right)i} $,其中k分别取值0,1,...,n-1的,我们记为w(n,0),w(n,1),...,w(n,n-1)。很显然w(n,i)w(n,j)=w(n,i+j),由于w(n,0)=w(n,n)=1,故我们得知w(n,i)=w(n,n-i)=w(n,i-n)。下面说明这n个复数的有趣性质: 消去引理:w(dn,dk)=w(n,k)。 证明:略 折半引理:2n次单位复数根的平方组成的集合与n次单位复数根组成的集合相同。 证明:对于0<=k 对于n<=k<2n,由于w(2n,n)=-1(看欧拉公式),故$ w\left(2n,k\right)^2=\left(-w\left(2n,k-n\right)\right)^2=w\left(2n,k-n\right)^2=w\left(n,k-n\right) $。 求和引理:对于任意n>=1和不能被n整除的非负整数k,有$ \sum_{j=0}^{n-1}{w\left(n,k\right)^j}=0 $。 证明:$ \sum_{j=0}^{n-1}{w\left(n,k\right)^j}=\frac{w\left(n,k\right)^n-1}{w\left(n,k\right)-1}=\frac{1-1}{w\left(n,k\right)-1}=0 $。 对于一个长度为n的多项式f(x),其在n个n次单位复数根w(n,0),w(n,1),...,w(n,n-1)上的取值组成的序列y=(f(w(n,0)),f(w(n,1)),...,f(w(n,n-1)))称为f的离散傅立叶变换(DFT),也记作y=DFT(f)。很显然DFT(f)可以唯一确定f。 要计算DFT(f),我们可以分治策略。 我们将f切分为长度为n/2的两个多项式even和odd,其中even中仅包含多项式的偶数项系数,而odd中仅包含奇数项系数: even=f0*x^0+f2*x^1+f4*x^2+...+fn-2x^(n/2-1) odd=f1*x^0+f3*x^1+...+fn-1x^(n/2-1) 而很显然f=even(x^2)+x*odd(x^2)。因此要计算f在w(n,0),w(n,1),...,w(n,n-1)上的取值,只需要计算even和odd在w(n,0)^2,w(n,1)^2,...,w(n,n-1)^2上的取值即可。由折半引理可知{w(n,0)^2,w(n,1)^2,...,w(n,n-1)^2}={w(n/2,0),w(n/2,1),...,w(n/2,n/2-1)},故我们实际上要计算的仅为DFT(even)和DFT(odd)。在得到DFT(even)和DFT(odd)后,仅需使用O(n)的时间复杂度即可算出DFT(f)。 我们记T(f)表示DFT(f)的时间复杂度,则T(f)=O(n)+2T(f/2)=...=O(kn)+(2^k)*T(f/(2^k))=...=O(nlog2(n))。 由于计算机计算正弦和余弦函数要花费大量的时间,因此可以将wn1作为参数传入,而在计算DFT(even)时,将wn1*wn1作为参数传入即可(w(n,1)^2=w(n/2,1)),这样可以节省时间。 离散傅立叶逆变换IDFT将多项式从点值表达式转换为系数表达式。即IDFT(DFT(c))=c。观察下面等式: $$ \left[\begin{array}{c} y_0\\ y_1\\ \vdots\\ y_{n-1} \end{array}\right]=\left[\begin{matrix} w\left(n,0\right)^0 & w\left(n,0\right)^1 &\cdots & w\left(n,0\right)^{n-1}\\ w\left(n,1\right)^0 & w\left(n,1\right)^1 &\cdots & w\left(n,1\right)^{n-1}\\ \vdots &\vdots &\ddots &\vdots\\ w\left(n,n-1\right)^0 & w\left(n,n-1\right)^1 &\cdots & w\left(n,n-1\right)^{n-1} \end{matrix}\right]\left[\begin{array}{c} c_0\\ c_1\\ \vdots\\ c_{n-1} \end{array}\right] $$ 等式左边为DFT(c),而做右边的向量为c,方阵为范德蒙德矩阵,由于w(n,0),...,w(n,n-1)互不相同(欧拉公式),故方阵可逆。我们将等式简记y=Mc。记IM=M^(-1)。 定理:IM的j'行j列元素为IMj'j=w(n,-j'j)/n 证明:记P=IM·M,显然$ P_{j'j}=\sum_{k=0}^{n-1}{w\left(n,-j'k\right)/n\cdot w\left(n,kj\right)}=\frac{1}{n}\sum_{k=0}^{n-1}{w\left(n,k\right)^{j-j'}} $。当j等于j'时,Pj'j=n/n=1,而当j不等于j'时,由求和引理可知Pj'j=0。故P是单位矩阵,因此IM=M^(-1)。 现在我们可以利用c=IM·y来计算c了。观察矩阵下隐含的等式关系: $$ c_j=\sum_{k=0}^{n-1}{y_k\cdot w\left(n,-jk\right)/n}=\frac{1}{n}\sum_{k=0}^{n-1}{y_kw\left(n,n-j\right)^k} $$ 这给了我们一个启发,c和DFT(y)/n中包含的值是相同的,只是顺序不同而已。因此我们可以利用DFT以及一些常数时间的操作实现IDFT。 IDFT和DFT共享相同的时间复杂度O(nlog2n)。 快速傅立叶变换利用DFT和IDFT计算两个多项式a(x)和b(x)的乘积。 FFT的第一步首先是找到一个合适的二次幂n,并将a和b通过添加0系数项的方式扩展到长度n。之后计算DFT(a)和DFT(b),在利用点乘计算DFT(c)=DFT(a)·DFT(b)。最后利用IDFT从点值表达式DFT(c)复原出多项式c,即c=FFT(a,b)=IDFT(DFT(a)·DFT(b))。 显然FFT的时间复杂度为DFT和IDFT的总和,也是O(nlog2n)。而由于n<2*|c|=2*(|a|+|b|),因此我们可以忽略n与|c|之间的误差,即时间复杂度可以写作O(|c|log2|c|)。是相当优秀的时间复杂度。 实际中,如果你直接使用上面的快速傅立叶变换,你将会发现性能相当不理想。我们仔细观察一下FFT的流程,我们能发现我们总共做了三次DFT计算(其中一次在调用IDFT中),和一次点乘运算,其中点乘运算是线性时间复杂度,而DFT的时间复杂度为O(nlogn)。如果我们能优化DFT,那么FFT的性能将得到最大幅度的提升。 那么DFT的缺陷在哪里,我们很容易发现每次我们都需要建立一个多项式y,以保存最终结果,我们完全可以在FFT中复制多项式a,b,之后在DFT中将结果直接写入到f中并作为返回值。之后我们还可以发现,每次用f调用DFT,需要创建两个子多项式odd和even,其长度总和为|f|。因此我们总共分配的多项式长度为O(nlog2n),这是相当大的空间复杂度。而如果我们能不用分配even和odd,那么自然能达到优化DFT的结果。 观察对于多项式(a0,a1,a2,a3,a4,a5,a6,a7)调用DFT的过程: (a0,a1,a2,a3,a4,a5,a6,a7) | \ (a0,a2,a4,a8) (a1,a3,a5,a7) | \ | \ (a0,a4) (a2,a8) (a1,a5) (a3,a7) | \ | \ | \ | \ a0 a4 a2 a8 a1 a5 a3 a7 如果我们能将多项式(a0,a1,a2,a3,a4,a5,a6,a7)在DFT的一开始转换为(a0,a4,a2,a8,a1,a5,a3,a7),那么我们就可以用非递归的方式实现DFT,一开始将长度为1的两个相邻多项式合并,之后将长度为2的两个相邻多项式合并,再将长度为4两个相邻多项式合并。 观察序列,能发现以下规律,下标二进制最低位为0的排在二进制最低位为1的之前。如果最低位相同,则用相同逻辑判断次低位。容易发现这相当于比较两个整数二进制(总共log2n位)逆序的大小。因此如果我们能一开始处理得到n个数值的二进制序列,之后利用逆序对多项式各位进行重排列,那么就能实现上述的非递归版本的DFT。 而要计算所有长度为m的二进制序列的所有逆序,可以用下面的方法: 其中我们使用了一个聪明的想法,i的二进制逆序,我们可以先将i右移1位得到j,由于j>1左位或处理,从而得到r[i]。很容易得出reverse的时间复杂度为O(n)。 最后给出非递归版本的DFT: 其中r=reverse(m)这个步骤我们可以提前执行。以及w1 = complex(cos(PI/s),sin(PI/s))中由于涉及到了cos和sin的计算,由于s始终是2的幂,因此可能的值非常少,也可以一开始就计算好。整个DFT中我们没有分配额外的空间,因此空间复杂度可以认作为O(n),而时间复杂度不变为O(nlog2n)。 单位复数根
离散傅立叶变换DFT
DFT(f)
n = f.length
if(n == 1)
return f[0]
y = empty-array
even = 0
odd = 0
for(i = 0; i < n / 2; i++)
even[i] = f[i * 2]
odd[i] = f[i * 2 + 1]
yEven = DFT(even)
yOdd = DFT(odd)
wn1 = e^((2*PI/n)i)
w = 1
for(i = 0; i < n / 2; i++)
y[i]= yEven[i] + w * yOdd[i]
y[i + n / 2] = yEven[i] - w * yOdd[i]
w = w * wn1
return y离散傅立叶逆变换IDFT
IDFT(y)
c = 0
n = y.length
dftY = DFT(y)
c[0] = dftY[0] / n
for(i = 1; i < n; i++)
c[i] = dftY[n - i] / n
return c
快速傅立叶变换FFT
FFT(a, b)
n = 1
while(n < |a| + |b|)
n = n * 2
extend(a, n)
extend(b, n)
return IDFT(DFT(a)·DFT(b))
非递归版本DFT
reverse(m)
n = 2^m
r = int[n]
r[0] = 0
for(i = 1; i < n; i++)
r[i] = (r[i >> 1] >> 1) | ((1 & i) << (m -1))
return r
DFT(p, m)
r = reverse(m)
n = 2^m
for(i = 0; i < n; i++)
if(r[i] > i)
swap(p, i, r[i])
for(d = 0; d < m; d++)
s = 2^d
w1 = complex(cos(PI/s),sin(PI/s))
for(i = 0; i < n; i += 2*s)
w = 1
for(j = 0; j < s; j++)
a = i + j
b = a + s
t = w * p[b]
q = p[a]
p[b] = q - t
p[a] = q + t
w = w * w1