DFT(Discrete Fourier Transform,离散傅里叶变换)是傅里叶变换在时域和频域上都呈离散的形式,而FFT,则是将DFT快速实现的一种方式,在计算机系统、数字系统中有重要作用,显著降低了运算的复杂性。
卷积 (Convolution),说是一种通过两个函数(f · g) 生成第三个函数的一种数学算子.
更具一般性,卷积定义为:
h ( x ) = ∫ − ∞ ∞ g ( τ ) f ( x − τ ) d τ h(x) = \int_{-\infty}^{\infty}g(\tau)f(x - \tau)d\tau h(x)=∫−∞∞g(τ)f(x−τ)dτ
形如这个式子,由 g ( τ ) g(\tau) g(τ) 与 f ( x − τ ) f(x - \tau) f(x−τ) 形成的 h ( x ) h(x) h(x) 就被称为 卷积
这个定义可以拓展到多项式域:
A ( x ) ⋅ B ( x ) = ∑ i = 0 n ∑ j = 0 i a j b i − j A(x) \cdot B(x) = \sum_{i= 0}^n \sum_{j = 0}^i a_j b_{i - j} A(x)⋅B(x)=i=0∑nj=0∑iajbi−j
其中, A ( x ) A(x) A(x) 和 B ( x ) B(x) B(x) 均为 n n n 次多项式
对这个多项式域的卷积合并同类项,最后可以得到 2 n + 1 2n + 1 2n+1 项
点表示法(Dot Method),是另外一种描述多项式的形式,更一般的,我们遇到的多项式都是采用 系数表示法,也就是下面的形式:
f ( x ) = a 0 + a 1 x + a 2 x 2 + a 3 x 3 + . . . + a n x n ⇔ f ( x ) = { a 0 , a 1 , a 2 . . . a n } f(x) = a_0 + a_1x + a_2x^2 + a_3x^3 + ... + a_nx^n \Leftrightarrow f(x) = \{{a_0,a_1,a_2... a_n}\} f(x)=a0+a1x+a2x2+a3x3+...+anxn⇔f(x)={a0,a1,a2...an}
而对于点值表示法来说,我们就不是选取多项式的系数来表示这个多项式,而是在这个多项式所描述的图像(曲线)上选择一些点,来描述这个多项式。
其思想为:对于一个 n n n 次多项式,我们选择 n + 1 n + 1 n+1 个点,只要这些点不重复(指点互相不呈现倍数关系),那么这 n + 1 n + 1 n+1 个点就能唯一确定一个多项式。
借助高斯消元的思想:两点确定直线,每多一个点,就可以多确定曲线的一个参数,那么 n + 1 n + 1 n+1 个点就能确定 n n n 个参数。
因此,我们将多项式视为一个函数,采用点值表示法描述如下:
f ( x k ) ⇔ a 0 + a 1 x k + a 2 x k 2 + . . . + a n x k n , 0 ≤ k ≤ n f ( x ) = a 0 + a 1 x + a 2 x 2 + . . . + a n x n ⇔ f ( x ) = { ( x 0 , y 0 ) , ( x 1 , y 1 ) . . . ( x n , y n ) } f(x_k) \Leftrightarrow a_0 + a_1x_k + a_2x_k^2 + ... + a_nx_k^n,0 \leq k \leq n \\ f(x) = a_0 + a_1x + a_2x^2 + ... + a_nx^n \Leftrightarrow f(x) = \{(x_0,y_0),(x_1,y_1)...(x_n,y_n)\} f(xk)⇔a0+a1xk+a2xk2+...+anxkn,0≤k≤nf(x)=a0+a1x+a2x2+...+anxn⇔f(x)={(x0,y0),(x1,y1)...(xn,yn)}
将其转换为多项式描述法,即设 F ( x ) = f ( x ) ⋅ g ( x ) F(x) = f(x) \cdot g(x) F(x)=f(x)⋅g(x)就有:
f ( x ) = { ( x 0 , f ( x 0 ) ) , ( x 1 , f ( x 1 ) ) . . . ( x n , f ( x n ) ) } g ( x ) = { ( x 0 , g ( x 0 ) ) , ( x 1 , g ( x 1 ) ) . . . ( x n , g ( x n ) ) } F ( x ) = { ( x 0 , f ( x 0 ) g ( x 0 ) ) , ( x 1 , f ( x 1 ) g ( x 1 ) ) . . . ( x n , f ( x n ) g ( x n ) ) } f(x) = \{(x_0,f(x_0)),(x_1,f(x_1))...(x_n,f(x_n))\} \\ g(x) = \{(x_0,g(x_0)),(x_1,g(x_1))...(x_n,g(x_n))\} \\ F(x) = \{(x_0,f(x_0)g(x_0)),(x_1,f(x_1)g(x_1))...(x_n,f(x_n)g(x_n))\} \\ f(x)={(x0,f(x0)),(x1,f(x1))...(xn,f(xn))}g(x)={(x0,g(x0)),(x1,g(x1))...(xn,g(xn))}F(x)={(x0,f(x0)g(x0)),(x1,f(x1)g(x1))...(xn,f(xn)g(xn))}
经过以上过程,我们就成功将一个多项式,从系数表达式转变为了点表达式,这一过程又称 D F T DFT DFT
接下来让我们回忆一下到目前为止我们对多项式的讨论:
我们把一个 n n n 阶的多项式视为一个函数,在其中选取了 n + 1 n + 1 n+1 个点,来描述这个多项式。
考虑一个朴素的做法:
—选取 n + 1 n + 1 n+1 个点,我们需要 Θ ( n 2 ) \Theta(n^2) Θ(n2) 时间
—点乘法,对于 n n n 阶多项式而言,我们需要 Θ ( n ) \Theta(n) Θ(n)
—将点值映射回系数表达式,至少需要 Θ ( n 2 ) \Theta(n^2) Θ(n2) 的时间
看起来更麻烦了,比朴素的乘法常数还要高,难道真的没有办法了吗?
这时候就是 F F T FFT FFT 出场的时候了, F F T FFT FFT 可以使得对于第一个选点和第三个将点映射回系数的操作从 Θ ( n 2 ) \Theta(n^2) Θ(n2) 变为 Θ ( n l o g n ) \Theta(nlogn) Θ(nlogn)
本质上, FFT 包含 DFT (离散傅立叶变换)和 IDFT (逆离散傅立叶变换)实际上, DFT 对应着的就是把系数表达式映射到点值表达式的过程, IDFT 对应着的就是我们把点值表达式映射到系数表达式的过程。
下面我们来讨论这个优化的过程。
首先我们假设有一个复数 w w w 其满足 w n = 1 w^n = 1 wn=1,易知其解 w w w 有 n n n 个,并且这 n n n 个解都满足 w w w 在复数平面上的模长都是 1 。根据复数平面的知识,两个复数相乘,等于其在复数平面上的向量模长相乘,幅角相加,因此这 n n n 个解我们可以理解为复数平面的 n n n 个单位向量,并且他们的幅角之和为360°,也就是把复数平面内的单位元 n n n 等分的 n n n 个向量。
考虑欧拉公式
e i x = c o s x + i s i n x e^{ix} = cosx + isinx eix=cosx+isinx
当x = 2π时,得:
e 2 π i = 1 = w n ⇔ w = e 2 π i n e^{2πi} = 1 = w^n \Leftrightarrow w = e^{\frac{2πi}{n}} e2πi=1=wn⇔w=en2πi
由此我们得到了这 n n n 个解的表示,其称为 主次单位根,记为:
w n = e 2 π i n w_n = e^{\frac{2πi}{n}} wn=en2πi
单位根具有如下性质:
w n n = 1 w d n d k = w n k , n ≥ 0 , k ≥ 0 , d > 0 ( w n k ) 2 = ( w n 2 k ) = ( w n 2 k ) w 2 n k + n = − w 2 n k w_{n}^{n} = 1 \\ w_{dn}^{dk} = w_{n}^{k},n \geq 0,k \geq 0,d > 0 \\ (w_{n}^{k})^2 = (w_n^{2k}) = (w_\frac {n}{2}^k) \\ w_{2n}^{k + n} = -w_{2n}^{k} wnn=1wdndk=wnk,n≥0,k≥0,d>0(wnk)2=(wn2k)=(w2nk)w2nk+n=−w2nk
介绍完单位复根的性质后,我们可以进入 F F T FFT FFT的环节了。
FFT采用了分治法来求取每一个 x = w n k x = w_n^k x=wnk 时的值,分治的关键在于将多项式化为奇函数与偶函数进行处理。
下面是一个对7阶多项式分治的例子:
f ( x ) = a 0 + a 1 x + a 2 x 2 + a 3 x 3 + a 4 x 4 + a 5 x 5 + a 6 x 6 + a 7 x 7 f ( x ) = ( a 0 + a 2 x 2 + a 4 x 4 + a 6 x 6 ) + ( a 1 x 1 + a 3 x 3 + a 5 x 5 + a 7 x 7 ) f ( x ) = ( a 0 + a 2 x 2 + a 4 x 4 + a 6 x 6 ) + x ( a 1 + a 3 x 2 + a 5 x 4 + a 7 x 6 ) f(x) = a_0 + a_1x + a_2x^2 + a_3x^3 + a_4x^4 + a_5x^5 + a_6x^6 + a_7x^7 \\ f(x) = (a_0 + a_2x^2 + a_4x^4 + a_6x^6) + (a_1x^1 + a_3x^3 + a_5x^5 + a_7x^7) \\ f(x) = (a_0 + a_2x^2 + a_4x^4 + a_6x^6) + x(a_1 + a_3x^2 + a_5x^4 + a_7x^6) \\ f(x)=a0+a1x+a2x2+a3x3+a4x4+a5x5+a6x6+a7x7f(x)=(a0+a2x2+a4x4+a6x6)+(a1x1+a3x3+a5x5+a7x7)f(x)=(a0+a2x2+a4x4+a6x6)+x(a1+a3x2+a5x4+a7x6)
接下来构造两个新函数:
g ( x ) = ( a 0 + a 2 x + a 4 x 2 + a 6 x 3 ) h ( x ) = ( a 1 + a 3 x + a 5 x 2 + a 7 x 3 ) f ( x ) = g ( x 2 ) + x ⋅ h ( x 2 ) g(x) = (a_0 + a_2x + a_4x^2 + a_6x^3) \\ h(x) = (a_1 + a_3x + a_5x^2 + a_7x^3) \\ f(x) = g(x^2) + x \cdot h(x^2) g(x)=(a0+a2x+a4x2+a6x3)h(x)=(a1+a3x+a5x2+a7x3)f(x)=g(x2)+x⋅h(x2)
由于每一个x都是单位复根,可以带入单位复根的性质:
f ( w n k ) = g ( ( w n k ) 2 ) + w n k ⋅ h ( ( w n k ) 2 ) = g ( w n 2 k ) + w n k ⋅ h ( w n 2 k ) = g ( w n 2 k ) + w n k ⋅ h ( w n 2 k ) \begin{aligned} f(w_n^k) =& g((w_n^k)^2) + w_n^k \cdot h((w_n^k)^2) \\ =& g(w_{n}^{2k}) + w_n^k \cdot h(w_{n}^{2k}) \\ =& g(w_{\frac {n}{2}}^{k}) + w_n^k \cdot h(w_{\frac {n}{2}}^{k}) \end{aligned} f(wnk)===g((wnk)2)+wnk⋅h((wnk)2)g(wn2k)+wnk⋅h(wn2k)g(w2nk)+wnk⋅h(w2nk)
类似有:
f ( w n k + n 2 ) = g ( w n 2 k + n ) + w n k + n 2 ⋅ h ( w n 2 k + n ) = g ( w n 2 k ) − w n k ⋅ h ( w n 2 k ) = g ( w n 2 k ) − w n k ⋅ h ( w n 2 k ) \begin{aligned} f(w_n^{k + \frac{n}{2}}) =& g(w_n^{2k + n}) + w_n^{k + \frac{n}{2}} \cdot h(w_n^{2k + n}) \\ =& g(w_n^{2k}) - w_n^{k} \cdot h(w_n^{2k}) \\ =& g(w_{\frac {n}{2}}^{k}) - w_n^{k} \cdot h(w_{\frac {n}{2}}^{k}) \end{aligned} f(wnk+2n)===g(wn2k+n)+wnk+2n⋅h(wn2k+n)g(wn2k)−wnk⋅h(wn2k)g(w2nk)−wnk⋅h(w2nk)
因此我们只要求得 g ( w n 2 k ) g(w_{\frac {n}{2}}^{k}) g(w2nk) 和 h ( w n 2 k ) h(w_{\frac {n}{2}}^{k}) h(w2nk) 就可以求出 f ( w n k ) f(w_n^k) f(wnk) 和 f ( w n k + n 2 ) f(w_n^{k + \frac{n}{2}}) f(wnk+2n) 也就是可以递归求解。
需要注意分治处理的时候的多项式长度必须是2的整数次幂,因此需要补0,最高次项为 2 m − 1 2^{m-1} 2m−1的形式。
由此可以总结出DFT变换的代码模板:
#include
#include
typedef std::complex<double> complex;
const double PI = acos(-1.0);
const int MAX_N = 1 << 20;
complex tmp[MAX_N];
void DFT(complex* f, int n, int rev) { // rev=1,DFT; rev=-1,IDFT
if (n == 1) return;
for (int i = 0; i < n; ++i) tmp[i] = f[i];
for (int i = 0; i < n; ++i) { // 偶数放左边,奇数放右边
if (i & 1)
f[n / 2 + i / 2] = tmp[i];
else
f[i / 2] = tmp[i];
}
complex* g = f, * h = f + n / 2;
DFT(g, n / 2, rev), DFT(h, n / 2, rev); // 递归 DFT
complex cur(1, 0), step(cos(2 * PI / n), sin(2 * PI * rev / n));
// Comp step=exp(I*(2*PI/n*rev)); // 两个 step 定义是等价的
for (int k = 0; k < n / 2; ++k) {
tmp[k] = g[k] + cur * h[k];
tmp[k + n / 2] = g[k] - cur * h[k];
cur *= step;
}
for (int i = 0; i < n; ++i) f[i] = tmp[i];
}
可以看到我们在上面的DFT变换中多次对当前的多项式的奇偶次幂进行划分,而这个过程是可以被预处理的,而递归的过程也需要非常大的内存,所以我们采取了一种"模仿递归"的操作,对原数组进行“拆分”与“合并”。
我们采取从小到大循环的方式,当我们循环到 R ( x ) R(x) R(x) 时, R ( ⌊ x 2 ⌋ ) R(\lfloor \frac{x}{2} \rfloor) R(⌊2x⌋) 是已知的,观察后可以发现,除了最高位, R ( x ) R(x) R(x) 剩余的位数都是由 R ( ⌊ x 2 ⌋ ) R(\lfloor \frac{x}{2} \rfloor) R(⌊2x⌋) 右移一位得来。
对于最高位,我们只要检查 x x x 本身是否为 1 1 1 ,如果是,对 R ( x ) R(x) R(x) 加上 l e n 2 \frac{len}{2} 2len 即可。
即:
R ( x ) = ⌊ R ( ⌊ x 2 ⌋ ) ⌋ + ( x m o d 2 ) ⋅ l e n 2 R(x) = \lfloor R(\lfloor \frac{x}{2} \rfloor) \rfloor + (x \ mod \ 2) \cdot \frac{len}{2} R(x)=⌊R(⌊2x⌋)⌋+(x mod 2)⋅2len
代码实现如下:
// 同样需要保证 len 是 2 的幂
// 记 rev[i] 为 i 翻转后的值
void change(Complex y[], int len) {
for (int i = 0; i < len; i++) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1) { // 如果最后一位是 1,则翻转成 len/2
rev[i] |= len >> 1;
}
}
for (int i = 0; i < len; i++) {
if (i < rev[i]) { // 保证每对数只翻转一次
swap(y[i], y[rev[i]]);
}
}
return;
}
前面介绍了DFT,其作用为将系数表示法转换为点表示法,现在我们有了点,还需要将其转换为系数,这就是IDFT。
经过了前面一长串的变化,我们求出了这个线性方程组左边的值,现在我们采用线性方程组的思想来考虑这个问题,给了我们方程组的向量,以及系数矩阵,如何求出方程组的解?很简单,只要我们左乘矩阵的逆即可。
观察这个矩阵,不难发现这是一个范德蒙方阵,在数学上可以证明其逆矩阵非常特殊,其每一项元素都是原矩阵的倒数再乘以 1 n \frac{1}{n} n1
此时我们结合欧拉公式,有一个惊人的发现:
1 w n = w n − 1 = e − 2 π i n = c o s ( 2 π n ) + i s i n ( − 2 π n ) \frac{1}{w_n} = w_n^{-1} = e^{-\frac{2πi}{n}} = cos(\frac{2π}{n}) + isin(-\frac{2π}{n}) wn1=wn−1=e−n2πi=cos(n2π)+isin(−n2π)
也就是说,我们只要再跑一遍对DFT的算法,只不过多乘一个 − 1 -1 −1 的系数,就可以将我们得到的点再转变为系数。
如下是模仿递归的代码:
/*
* 做 FFT
* len 必须是 2^k 形式
* on == 1 时是 DFT,on == -1 时是 IDFT
*/
void fft(Complex y[], int len, int on) {
change(y, len);
for (int h = 2; h <= len; h <<= 1) { // 模拟合并过程
Complex wn(cos(2 * PI / h), sin(on * 2 * PI / h)); // 计算当前单位复根
for (int j = 0; j < len; j += h) {
Complex w(1, 0); // 计算当前单位复根
for (int k = j; k < j + h / 2; k++) {
Complex u = y[k];
Complex t = w * y[k + h / 2];
y[k] = u + t; // 这就是把两部分分治的结果加起来
y[k + h / 2] = u - t;
// 后半个 “step” 中的ω一定和 “前半个” 中的成相反数
// “红圈”上的点转一整圈“转回来”,转半圈正好转成相反数
// 一个数相反数的平方与这个数自身的平方相等
w = w * wn;
}
}
}
if (on == -1) {
for (int i = 0; i < len; i++) {
y[i].x /= len;
}
}
}
完整实现:
#include
#include
#include
using namespace std;
struct complex {
double x, y;
complex(double _x = 0.0, double _y = 0.0) {
x = _x;
y = _y;
}
complex operator-(const complex& b) const {
return complex(x - b.x, y - b.y);
}
complex operator+(const complex& b) const {
return complex(x + b.x, y + b.y);
}
complex operator*(const complex& b) const {
return complex(x * b.x - y * b.y, x * b.y + y * b.x);
}
};
const double PI = acos(-1.0);
const int MAXN = 200020;
complex x1[MAXN], x2[MAXN];
char str1[MAXN / 2], str2[MAXN / 2];
int sum[MAXN],rev[MAXN];
// 同样需要保证 len 是 2 的幂
// 记 rev[i] 为 i 翻转后的值
void change(complex y[], int len) {
for (int i = 0; i < len; i++) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1) { // 如果最后一位是 1,则翻转成 len/2
rev[i] |= len >> 1;
}
}
for (int i = 0; i < len; i++) {
if (i < rev[i]) { // 保证每对数只翻转一次
swap(y[i], y[rev[i]]);
}
}
return;
}
/*
* 做 FFT
* len 必须是 2^k 形式
* on == 1 时是 DFT,on == -1 时是 IDFT
*/
void fft(complex y[], int len, int on) {
change(y, len);
for (int h = 2; h <= len; h <<= 1) { // 模拟合并过程
complex wn(cos(2 * PI / h), sin(on * 2 * PI / h)); // 计算当前单位复根
for (int j = 0; j < len; j += h) {
complex w(1, 0); // 计算当前单位复根
for (int k = j; k < j + h / 2; k++) {
complex u = y[k];
complex t = w * y[k + h / 2];
y[k] = u + t; // 这就是把两部分分治的结果加起来
y[k + h / 2] = u - t;
// 后半个 “step” 中的ω一定和 “前半个” 中的成相反数
// “红圈”上的点转一整圈“转回来”,转半圈正好转成相反数
// 一个数相反数的平方与这个数自身的平方相等
w = w * wn;
}
}
}
if (on == -1) {
for (int i = 0; i < len; i++) {
y[i].x /= len;
}
}
}
int main() {
while (scanf("%s%s", str1, str2) == 2) {
int len1 = strlen(str1);
int len2 = strlen(str2);
int len = 1;
while (len < len1 * 2 || len < len2 * 2) len *= 2;
for (int i = 0; i < len1; i++) x1[i] = complex(str1[len1 - 1 - i] - '0', 0);
for (int i = len1; i < len; i++) x1[i] = complex(0, 0);
for (int i = 0; i < len2; i++) x2[i] = complex(str2[len2 - 1 - i] - '0', 0);
for (int i = len2; i < len; i++) x2[i] = complex(0, 0);
fft(x1, len, 1);
fft(x2, len, 1);
for (int i = 0; i < len; i++) x1[i] = x1[i] * x2[i];
fft(x1, len, -1);
for (int i = 0; i < len; i++) sum[i] = int(x1[i].x + 0.5);
for (int i = 0; i < len; i++) {
sum[i + 1] += sum[i] / 10;
sum[i] %= 10;
}
len = len1 + len2 - 1;
while (sum[len] == 0 && len > 0) len--;
for (int i = len; i >= 0; i--) printf("%c", sum[i] + '0');
printf("\n");
}
return 0;
}
代码部分参考各个大佬的算法模板,这里主要记录第一次学习时的思路。