fft学习小记

Preface

这东西要打多几遍。

FFT 快速傅里叶变换

核心思想:

利用单位复数根的性质,尝试分治,在 O ( n l o g n ) O(nlogn) O(nlogn)的时间内完成点值和插值运算。

单位复数根

( 1 , 0 ) (1,0) (1,0)出发,每次逆时针旋转 2 π n \frac{2π}{n} n2π弧得到一个新的复数根。所以单位复数根可以表示为: ω ( c o s ( 2 ∗ i ∗ p i / m ) , s i n ( 2 ∗ i ∗ p i / m ) ) \omega(cos(2*i*pi/m),sin(2*i*pi/m)) ω(cos(2ipi/m),sin(2ipi/m))

根据复数运算法则,可以知道,两个复数相乘,实际上是幅角相加,模长相乘。

所以有 ω n i + j = ω n i ∗ ω n j \omega_{n}^{i+j}=\omega_n^i*\omega_n^j ωni+j=ωniωnj

那么单位复数根就有以下性质:

  1. 折半引理: ( ω n k + n / 2 ) 2 = ω n 2 k + n = ω n 2 k ω n n = ω n 2 k = ( ω n k ) 2 (\omega_n^{k+n/2})^2=\omega_n^{2k+n}=\omega_n^{2k}\omega_n^n=\omega_n^{2k}=(\omega_n^k)^2 (ωnk+n/2)2=ωn2k+n=ωn2kωnn=ωn2k=(ωnk)2

这个引理保证了FFT分治策略的时间复杂度。

  1. 消去引理:

n n n不整除 k k k
∑ j = 0 n − 1 ( ω n k ) j = ( ω n k ) n − 1 ω n k − 1 = ( ω n n ) k − 1 ω n k − 1 = ( 1 ) k − 1 ω n k − 1 = 0 \sum_{j=0}^{n-1}(\omega_n^k)^j=\frac{(\omega_n^k)^n-1}{\omega_n^k-1}=\frac{(\omega_n^n)^k-1}{\omega_n^k-1}=\frac{(1)^k-1}{\omega_n^k-1}=0 j=0n1(ωnk)j=ωnk1(ωnk)n1=ωnk1(ωnn)k1=ωnk1(1)k1=0

n n n整除 k k k时,就有 ω n k = 1 \omega_{n}^k=1 ωnk=1,所以上式的值显然为 n n n.

这个引理推出FFT的插值运算与点值运算的关系,是插值运算的基础。

分治策略

考虑一个多项式 A ( x ) A(x) A(x),按照系数下标的奇偶来分成两个次数界为 n 2 \frac{n}{2} 2n的多项式 A 0 , A 1 A_0,A_1 A0,A1,即 A [ 0 ] ( x ) = a 0 + a 2 x + a 4 x 2 + . . . + a n − 2 x n / 2 − 1 A^{[0]}(x)=a_0+a_2x+a_4x^2+...+a_{n-2}x^{n/2-1} A[0](x)=a0+a2x+a4x2+...+an2xn/21 A [ 1 ] ( x ) = a 1 + a 3 x + a 5 x 2 + . . . + a n − 1 x n / 2 − 1 A^{[1]}(x)=a_1+a_3x+a_5x^2+...+a_{n-1}x^{n/2-1} A[1](x)=a1+a3x+a5x2+...+an1xn/21

然后可以写出: A ( x ) = A [ 0 ] ( x 2 ) + A [ 1 ] ( x 2 ) x A(x)=A^{[0]}(x^2)+A^{[1]}(x^2)x A(x)=A[0](x2)+A[1](x2)x

然后根据折半引理,可以分成两个子问题求解。

时间复杂度就是: T ( n ) = 2 T ( n / 2 ) + Θ ( n ) = Θ ( n l o g ( n ) ) T(n)=2T(n/2)+\Theta(n)=\Theta(nlog(n)) T(n)=2T(n/2)+Θ(n)=Θ(nlog(n))

逆DFT

实际上就是上面的逆运算。

根据矩阵的思想: [ y 0 y 1 y 2 y 3 ⋮ y n − 1 ] = [ 1 1 1 1 … 1 1 ω n ω n 2 ω n 3 … ω n n − 1 1 ω n 2 ω n 4 ω n 6 … ω n 2 ( n − 1 ) 1 ω n 3 ω n 6 ω n 9 … ω n 3 ( n − 1 ) ⋮ ⋮ ⋮ ⋮ ⋱ ⋮ 1 ω n n − 1 ω n 2 ( n − 1 ) ω n 3 ( n − 1 ) … ω n ( n − 1 ) ( n − 1 ) ] [ a 0 a 1 a 2 a 3 ⋮ a n − 1 ] \begin{bmatrix} y_0 \\ y_1 \\ y_2 \\ y_3 \\ \vdots \\ y_{n-1} \end{bmatrix}= \begin{bmatrix} 1 & 1 & 1 & 1 & \dots & 1 \\ 1 & \omega_n & \omega_n^2 & \omega_n^ 3 & \dots & \omega_n^{n-1} \\ 1 & \omega_n^2 & \omega_n^4 & \omega_n^6 & \dots & \omega_n^{2(n-1)} \\ 1 & \omega_n^3 & \omega_n^6 & \omega_n^9 & \dots & \omega_n^{3(n-1)} \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \omega_n^{3(n-1)} & \dots & \omega_n^{(n-1)(n-1)} \end{bmatrix} \begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \\ \vdots \\ a_{n-1} \end{bmatrix} y0y1y2y3yn1=111111ωnωn2ωn3ωnn11ωn2ωn4ωn6ωn2(n1)1ωn3ωn6ωn9ωn3(n1)1ωnn1ωn2(n1)ωn3(n1)ωn(n1)(n1)a0a1a2a3an1来求出新的 ω \omega ω取值。

可以证明就是: [ V n − 1 V n ] j j ′ = ∑ k = 0 n − 1 ( ω n − k j / n ) ( ω n k j ′ ) = ∑ k = 0 n − 1 ω n k ( j ′ − j ) / n [V_n^{-1}V_n]_{jj'}=\sum_{k=0}^{n-1}(\omega_n^{-kj}/n)(\omega_n^{kj'})=\sum_{k=0}^{n-1}\omega_n^{k(j'-j)}/n [Vn1Vn]jj=k=0n1(ωnkj/n)(ωnkj)=k=0n1ωnk(jj)/n
然后做一遍就好了。

Code

#include 

#define F(i,a,b) for (int i=a;i<=b;i++)
typedef double db;

const int M = 4 * 1e5 + 10;
const db pi = acos(- 1);

using namespace std;

int cnt, n, m, len;

struct Z {
	db x, y;
	Z (db _x=0, db _y=0) { x = _x, y = _y; }
	friend Z operator + (Z a, Z b) { return Z(a.x + b.x, a.y + b.y); }
	friend Z operator - (Z a, Z b) { return Z(a.x - b.x, a.y - b.y); }
	friend Z operator * (Z a, Z b) { return Z(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
} a[M], b[M], t[M];

void Dft(Z *a, int n, int sig) {
	F(i, 0, n - 1) {
		int pos = 0;
		for (int x = i, y = 0; y < cnt; y ++, x >>= 1) pos = (pos << 1) + (x & 1);
		t[pos] = a[i];
	}
	for (int m = 2; m <= n; m <<= 1) {
		int half = m >> 1;
		Z W(cos(2 * sig * pi / m), sin(2 * sig * pi / m));
		for (int i = 0; i < n; i += m) {
			Z o(1, 0);
			for (int j = i; j < i + half; j ++, o = o * W) {
				Z u = t[j + half] * o;
				t[j + half] = t[j] - u;
				t[j] = t[j] + u;
			}
		}
	}
	F(i, 0, n - 1)
		a[i] = t[i];
	if (sig == - 1) {
		F(i, 0, n - 1)
			a[i].x = int(a[i].x / n + 0.5);
	}
}

int main() {
	scanf("%d%d", &n, &m);
	F(i, 0, n) scanf("%lf", &a[i].x);
	F(i, 0, m) scanf("%lf", &b[i].x);
	
	for (len = 1, cnt = 0; len <= n + m + 1; len <<= 1, cnt ++);
	Dft(a, len, 1), Dft(b, len, 1);
	F(i, 0, len - 1) a[i] = a[i] * b[i];
	Dft(a, len, - 1);

	F(i, 0, n + m) printf("%0.lf ", a[i].x);
}

你可能感兴趣的:(fft学习小记)