【知识总结】快速傅里叶变换(FFT)

这可能是我第五次学FFT了……菜哭qwq

先给出一些个人认为非常优秀的参考资料:

一小时学会快速傅里叶变换(Fast Fourier Transform) - 知乎

小学生都能看懂的FFT!!! - 胡小兔 - 博客园

快速傅里叶变换(FFT)用于计算两个 n n n次多项式相乘,能把复杂度从朴素的 O ( n 2 ) O(n^2) O(n2)优化到 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n)。一个常见的应用是计算大整数相乘。

本文中所有多项式默认 x x x为变量,其他字母均为常数。所有角均为弧度制。

一、多项式的两种表示方法

我们平时常用的表示方法称为“系数表示法”,即

A ( x ) = ∑ i = 0 n a i x i A(x)=\sum _{i=0}^n a_ix^i A(x)=i=0naixi

上面那个式子也可以看作一个以 x x x为自变量的 n n n次函数。用 n + 1 n+1 n+1个点可以确定一个 n n n次函数(自行脑补初中学习的二次函数)。所以,给定 n + 1 n+1 n+1 x x x和对应的 A ( x ) A(x) A(x),就可以求出原多项式。用 n + 1 n+1 n+1个点表示一个 n n n次多项式的方式称为“点值表示法”。

在“点值表示法”中,两个多项式相乘是 O ( n ) O(n) O(n)的。因为对于同一个 x x x,把它代入 A A A B B B求值的结果之积就是把它带入多项式 A × B A\times B A×B求值的结果(这是多项式乘法的意义)。所以把点值表示法下的两个多项式的 n + 1 n+1 n+1个点的值相乘即可求出两多项式之积的点值表示。

线性复杂度点值表示好哇好

但是,把系数表示法转换成点值表示法需要对 n + 1 n+1 n+1个点求值,而每次求值是 O ( n ) O(n) O(n)的,所以复杂度是 O ( n 2 ) O(n^2) O(n2)。把点值表示法转换成系数表示法据说也是 O ( n 2 ) O(n^2) O(n2)的(然而我只会 O ( n 3 ) O(n^3) O(n3)的高斯消元qwq)。所以暴力取点然后算还不如直接朴素算法相乘……

但是有一种神奇的算法,通过取一些具有特殊性质的点可以把复杂度降到 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n)

二、单位根

从现在开始,所有 n n n都默认是 2 2 2的非负整数次幂,多项式次数为 n − 1 n-1 n1。应用时如果多项式次数不是 2 2 2的非负整数次幂减 1 1 1,可以加系数为 0 0 0的项补齐。

先看一些预备知识:

复数 a + b i a+bi a+bi可以看作平面直角坐标系上的点 ( a , b ) (a,b) (a,b)。这个点到原点的距离称为模长,即 a 2 + b 2 \sqrt{a^2+b^2} a2+b2 ;原点与 ( a , b ) (a,b) (a,b)所连的直线与实轴正半轴的夹角称为辐角,即 s i n − 1 b a sin^{-1}\frac{b}{a} sin1ab。复数相乘的法则:模长相乘,辐角相加

把以原点为圆心, 1 1 1为半径的圆(称为“单位圆”) n n n等分, n n n个点中辐角最小的等分点(不考虑 1 1 1)称为 n n n单位根,记作 ω n \omega_n ωn,则这 n n n个等分点可以表示为 ω n k ( 0 ≤ k < n ) \omega_n^k(0\leq k < n) ωnk(0k<n)

这里如果不理解,可以考虑周角是 2 π 2\pi 2π n n n次单位根的辐角是 2 π n \frac{2\pi}{n} n2π w n k = w n k − 1 × w n 1 w_n^k=w_n^{k-1}\times w_n^1 wnk=wnk1×wn1,复数相乘时模长均为 1 1 1,相乘仍为 1 1 1。辐角 2 π ( k − 1 ) n \frac{2\pi (k-1)}{n} n2π(k1)加上单位根的辐角 2 π n \frac{2\pi}{n} n2π变成 2 π k n \frac{2\pi k}{n} n2πk

单位根具有如下性质:

1.折半引理

w 2 n 2 k = w n k w_{2n}^{2k}=w_n^k w2n2k=wnk

模长都是 1 1 1,辐角 2 π × 2 k 2 n = 2 π k n \frac{2\pi \times 2k}{2n}=\frac{2\pi k}{n} 2n2π×2k=n2πk,故相等。

2.消去引理

w n k + n 2 = − w n k w_n^{k+\frac{n}{2}}=-w_n^k wnk+2n=wnk

这个从几何意义上考虑, w n k + n 2 w_n^{k+\frac{n}{2}} wnk+2n的辐角刚好比 w n k w_n^k wnk多了 2 π × n 2 n = π \frac{2\pi \times \frac{n}{2}}{n}=\pi n2π×2n=π,刚好是一个平角,所以它们关于原点中心对称。互为相反数的复数关于原点中心对称。

3.(不知道叫什么的性质)其中 k k k是整数

w n a + k n = w n a w_n^{a+kn}=w_n^a wna+kn=wna

这个也很好理解: w n n w_n^n wnn的辐角是 2 π 2\pi 2π,也就是转了一整圈回到了实轴正半轴上,这个复数就是实数 1 1 1。乘上一个 w n n w_n^n wnn就相当于给辐角加了一个周角,不会改变位置。

三、离散傅里叶变换(DFT)

DFT把多项式从系数表示法转换到点值表示法。

我们大力尝试把 n n n次单位根的 0 0 0 n − 1 n-1 n1次幂分别代入 n − 1 n-1 n1次多项式 A ( x ) A(x) A(x)。首先先对 A ( x ) A(x) A(x)进行奇偶分组,得到:

A 1 ( x ) = ∑ i = 0 n − 1 2 a 2 i ⋅ x i A_1(x)=\sum_{i=0}^{\frac{n-1}{2}}a_{2i}·x^i A1(x)=i=02n1a2ixi

A 2 ( x ) = ∑ i = 0 n − 1 2 a 2 i + 1 ⋅ x i A_2(x)=\sum_{i=0}^{\frac{n-1}{2}}a_{2i+1}·x^i A2(x)=i=02n1a2i+1xi

则有:

A ( x ) = A 1 ( x 2 ) + x ⋅ A 2 ( x 2 ) A(x)=A_1(x^2)+x·A_2(x^2) A(x)=A1(x2)+xA2(x2)

w n k w_n^k wnk代入,得:

A ( w n k ) = A 1 ( w n 2 k ) + w n k ⋅ A 2 ( w n 2 k ) A(w_n^k)=A_1(w_n^{2k})+w_n^k·A_2(w_n^{2k}) A(wnk)=A1(wn2k)+wnkA2(wn2k)

根据折半引理,有:

A ( w n k ) = A 1 ( w n 2 k ) + w n k ⋅ A 2 ( w n 2 k ) A(w_n^k)=A_1(w_{\frac{n}{2}}^k)+w_n^k·A_2(w_{\frac{n}{2}}^k) A(wnk)=A1(w2nk)+wnkA2(w2nk)

此时有一个特殊情况。当 n 2 ≤ k < n \frac{n}{2}\leq k < n 2nk<n,记 a = k − n 2 a=k-\frac{n}{2} a=k2n,则根据消去引理和上面第三个性质,有:

w n a = − w n k w_n^a=-w_n^k wna=wnk

w n 2 a = w n 2 k w_{\frac{n}{2}}^a=w_{\frac{n}{2}}^k w2na=w2nk

所以

A ( w n k ) = A 1 ( w n 2 a ) − w n a ⋅ A 2 ( w n 2 a ) A(w_n^k)=A_1(w_{\frac{n}{2}}^a)-w_n^a·A_2(w_{\frac{n}{2}}^a) A(wnk)=A1(w2na)wnaA2(w2na)

这样变换主要是为了防止右侧式子里出现 w n w_n wn的不同次幂。

按照这个式子可以递归计算。共递归 O ( l o g 2 n ) O(log_2n) O(log2n)层,每层需要 O ( n ) O(n) O(n)枚举 k k k,因此可以在 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n)内把系数表示法变为点值表示法。

四、离散傅里叶反变换(IDFT)

w n k ( 0 ≤ k < n ) w_n^k(0\leq k<n) wnk(0k<n)代入多项式 A ( x ) A(x) A(x)后得到的点值为 b k b_k bk,令多项式 B ( x ) B(x) B(x)

B ( x ) = ∑ i = 0 n − 1 b i x i B(x)=\sum_{i=0}^{n-1}b_ix^i B(x)=i=0n1bixi

一个结论:设 w n − k ( 0 ≤ k < n ) w_n^{-k}(0\leq k<n) wnk(0k<n)代入 B ( x ) B(x) B(x)后得到的点值为 c k c_k ck,则多项式 A ( x ) A(x) A(x)的系数 a k = c k n a_k=\frac{c_k}{n} ak=nck。下面来证明这个结论。

c k = ∑ i = 0 n − 1 b i ⋅ w n − i k = ∑ i = 0 n − 1 ∑ j = 0 n − 1 a j ⋅ w n i j ⋅ w n − i k = ∑ j = 0 n − 1 a j ∑ i = 0 n − 1 w n i ( j − k ) \begin{aligned} c_k&=\sum_{i=0}^{n-1}b_i·w_n^{-ik}\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j·w_n^{ij}·w_n^{-ik}\\ &=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}w_n^{i(j-k)} \end{aligned} ck=i=0n1biwnik=i=0n1j=0n1ajwnijwnik=j=0n1aji=0n1wni(jk)

脑补一下 ∑ i = 0 n − 1 w n i ( j − k ) \sum_{i=0}^{n-1}w_n^{i(j-k)} i=0n1wni(jk)怎么求。可以看出这是一个公比为 w n j − k w_n^{j-k} wnjk的等比数列。

j = k j=k j=k w n 0 = 1 w_n^0=1 wn0=1,所以上式的值是 n n n

否则,根据等比数列求和公式,上式等于 w n j − k ⋅ w n n ( j − k ) − 1 w n j − k − 1 w_n^{j-k}·\frac{w_n^{n(j-k)}-1}{w_n^{j-k}-1} wnjkwnjk1wnn(jk)1 w n n ( j − k ) w_n^{n(j-k)} wnn(jk)相当于转了整整 ( j − k ) (j-k) (jk)圈,所以值为 1 1 1,这个等比数列的和为 0 0 0

由于当 j ≠ k j \neq k j̸=k时上述等比数列值为 0 0 0,所以 c k = a k n c_k=a_kn ck=akn,即 a k = c k n a_k=\frac{c_k}{n} ak=nck

至此,已经可以写出递归的FFT代码了。(常数大的一批qwq

实测洛谷3803有 77 77 77分,会TLE两个点。

下面放上部分代码。建议继续阅读之前先充分理解这种写法。

const int N = (1e6 + 10) * 4;
const double PI = 3.141592653589793238462643383279502884197169399375105820974944;
struct cpx
{
	double a, b;
	cpx(){}
	cpx(const double x, const double y = 0)
		: a(x), b(y){}
	cpx operator + (const cpx &c) const
	{
		return (cpx){a + c.a, b + c.b};
	}
	cpx operator - (const cpx &c) const
	{
		return (cpx){a - c.a, b - c.b};
	}
	cpx operator * (const cpx &c) const
	{
		return (cpx){a * c.a - b * c.b, a * c.b + b * c.a};
	}
};
int n, m;
cpx a[N], b[N], buf[N];
inline cpx omega(const int n, const int k)
{
	return (cpx){cos(2 * PI * k / n), sin(2 * PI * k / n)};
}
void FFT(cpx *a, const int n, const bool inv)
{
	if (n == 1)
		return;
	static cpx buf[N];
	int mid = n >> 1;
	for (int i = 0; i < mid; i++)
	{
		buf[i] = a[i << 1];
		buf[i + mid] = a[i << 1 | 1];
	}
	memcpy(a, buf, sizeof(cpx[n]));
	//now a[i] is coefficient
	FFT(a, mid, inv), FFT(a + mid, mid, inv);
	//now a[i] is point value
	//a[i] is A1(w_n^i), a[i + mid] is A2(w_n^i)
	for (int i = 0; i < mid; i++)
	{//calculate point value of A(w_n^i) and A(w_n^{i+n/2})
		cpx x = omega(n, i * (inv ? -1 : 1));
		buf[i] = a[i] + x * a[i + mid];
		buf[i + mid] = a[i] - x * a[i + mid];
	}
	memcpy(a, buf, sizeof(cpx[n]));
}
int work()
{
	read(n), read(m);
	for (int i = 0; i <= n; i++)
	{
		int tmp;
		read(tmp);
		a[i] = tmp;
	}
	for (int i = 0; i <= m; i++)
	{
		int tmp;
		read(tmp);
		b[i] = tmp;
	}
	for (m += n, n = 1; n <= m; n <<= 1);
	FFT(a, n, false), FFT(b, n, false);
	for (int i = 0; i < n; i++)
		a[i] = a[i] * b[i];
	FFT(a, n, true);
	for (int i = 0; i <= m; i++)
		write((int)((a[i].a / n) + 0.5)), putchar(' ');
	return 0;
}

五、优化

递归太慢了,我们用迭代。

考虑奇偶分组的过程。每一次把奇数项分到前面,偶数项分到后面,如 { a 0 , a 1 , a 2 , a 3 , a 4 , a 5 , a 6 , a 7 } \{a_0,a_1,a_2,a_3,a_4,a_5,a_6,a_7\} {a0,a1,a2,a3,a4,a5,a6,a7},按照这个过程分组,最终每组只剩一个数的时候是 { a 0 , a 4 , a 2 , a 6 , a 1 , a 5 , a 3 , a 7 } \{a_0,a_4,a_2,a_6,a_1,a_5,a_3,a_7\} {a0,a4,a2,a6,a1,a5,a3,a7}。经过仔mo细bai观da察lao,发现 1 ( 10 ) = 00 1 ( 2 ) 1_{(10)}=001_{(2)} 1(10)=001(2) 4 ( 10 ) = 10 0 ( 2 ) 4_{(10)}=100_{(2)} 4(10)=100(2),一个数最终变成的数的下标是它的下标的二进制表示颠倒过来(并不知道为什么)。我们可以递推算这个(其中lg2是 l o g 2 n log_2n log2n):

rev[i] = rev[i >> 1] >> 1 | ((i & 1) << (lg2 - 1))

可以先生成原数组经过 l o g 2 n log_2n log2n次奇偶分组的最终状态,然后一层一层向上合并即可。

另外,标准库中的三角函数很慢,可以打出 w n k w_n^k wnk w n − k w_n^{-k} wnk的表(或者只打一个表,因为 w n − k = w n n − k w_n^{-k}=w_n^{n-k} wnk=wnnk)。当前分治的区间长度为 l l l时,查询 w l k w_l^k wlk相当于查询 w n n k l w_n^{\frac{nk}{l}} wnlnk(这里要小心 n k nk nk爆int……血的教训)。

代码如下(洛谷1919)

#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;

namespace zyt
{
	template<typename T>
	inline void read(T &x)
	{
		char c;
		bool f = false;
		x = 0;
		do
			c = getchar();
		while (c != '-' && !isdigit(c));
		if (c == '-')
			f = true, c = getchar();
		do
			x = x * 10 + c - '0', c = getchar();
		while (isdigit(c));
		if (f)
			x = -x;
	}
	inline void read(char &c)
	{
		do
			c = getchar();
		while (!isgraph(c));
	}
	template<typename T>
	inline void write(T x)
	{
		static char buf[20];
		char *pos = buf;
		if (x < 0)
			putchar('-'), x = -x;
		do
			*pos++ = x % 10 + '0';
		while (x /= 10);
		while (pos > buf)
			putchar(*--pos);
	}
	const int N = (1 << 17) + 11;
	const double PI = acos(-1.0L);
	struct cpx
	{
		double a, b;
		cpx(const double x = 0, const double y = 0)
			:a(x), b(y) {}
		cpx operator + (const cpx &c) const
		{
			return (cpx){a + c.a, b + c.b};
		}
		cpx operator - (const cpx &c) const
		{
			return (cpx){a - c.a, b - c.b};
		}
		cpx operator * (const cpx &c) const
		{
			return (cpx){a * c.a - b * c.b, a * c.b + b * c.a};
		}
		cpx conj() const
		{
			return (cpx){a, -b};
		}
		~cpx(){}
	}omega[N], inv[N];
	int rev[N];
	void FFT(cpx *a, const int n, const cpx *w)
	{
		for (int i = 0; i < n; i++)
			if (i < rev[i])
				swap(a[i], a[rev[i]]);
		for (int len = 1; len < n; len <<= 1)
			for (int i = 0; i < n; i += (len << 1))
				for (int k = 0; k < len; k++)
				{
					cpx tmp = a[i + k] - w[k * (n / (len << 1))] * a[i + len + k];
					a[i + k] = a[i + k] + w[k * (n / (len << 1))] * a[i + len + k];
					a[i + len + k] = tmp;
				}
	}
	void init(const int lg2)
	{
		for (int i = 0; i < (1 << lg2); i++)
		{
			rev[i] = rev[i >> 1] >> 1 | (i & 1) << (lg2 - 1);
			omega[i] = (cpx){cos(2 * PI * i / (1 << lg2)), sin(2 * PI * i / (1 << lg2))};
			inv[i] = omega[i].conj();
		}
	}
	int work()
	{
		int n;
		static cpx a[N], b[N];
		read(n);
		for (int i = 0; i < n; i++)
		{
			char c;
			read(c);
			a[i] = c - '0';
		}
		for (int i = 0; i < n; i++)
		{
			char c;
			read(c);
			b[i] = c - '0';
		}
		for (int i = 0; (i << 1) < n; i++)
			swap(a[i], a[n - i - 1]), swap(b[i], b[n - i - 1]);
		int lg2 = 0, tmp = n << 1;
		for (n = 1; n < tmp; ++lg2, n <<= 1);
		init(lg2);
		FFT(a, n, omega), FFT(b, n, omega);
		for (int i = 0; i < n; i++)
			a[i] = a[i] * b[i];
		FFT(a, n, inv);
		bool st = false;
		static int ans[N];
		for (int i = 0; i < n; i++, n += (ans[n]))
		{
			ans[i] += (int)(a[i].a / n + 0.5);
			ans[i + 1] += ans[i] / 10;
			ans[i] %= 10;
		}
		for (int i = n - 1; i >= 0; i--)
			if (st || ans[i])
				write(ans[i]), st = true;
		return 0;
	}
}
int main()
{
	return zyt::work();
}

你可能感兴趣的:(数学)