昨天参悟了一天FFT,总算是理解了,今天的莫比乌斯反演也不太懂,干脆弃疗,决定来认真水一发博客。
什么是FFT?
FFT(Fast Fourier Transformation),即为快速傅氏变换,是离散傅氏变换(DFT)的快速算法,它是根据离散傅氏变换的奇、偶、虚、实等特性,对离散傅立叶变换的算法进行改进获得的。
FFT的作用?
主要用于加速多项式乘法(形如an x^n + a(n - 1) x^(n - 1) + …… + a1 x + a0),同时可以优化很多与多项式乘法相近的内容,比如高精度乘法(令x为10)。
先明确几个概念:
复数:
由两个部分组成,实数部分,虚数部分,形如 :a,ib(a为实数部分) 其中i^2 = -1,显然i不是一个实数。
复数的运算法则:
加法:实数部分相加,虚数部分相加
减法:实数部分相减,虚数部分相减
乘法:
我们来举一个例子:
(a,ib)* (c,id)
= ac + iad + ibc + i^2bd
= (ac - bd) + i(ad + bc)
=(ac - bd,i(ad + bc))
(i ^ 2 = -1)
我们考虑用坐标系来表示一下复数,
可以理解一下,对后文的一些讲解会有所帮助。(注意y轴的默认单位长度为i)
由坐标轴可以得出,复数(a,ib)的模长为sqrt(a^2 + b^2)
同理我么可以得出复数的乘法运算的直观体现,模长相乘,幅角相加。(自己可以带入两个(1,i1)计算来很好的理解)
多项式的系数表示与点值表示。
我们知道一个最高次项为n的多项式,有n + 1个系数,x^n……x^0的对应的系数。
如果我们将这n+1个系数构成一个n+1维的向量,显然可以唯一的确定出一个多项式。
那么这个向量就是系数表达式。
如果我们带入n个数字,求算出n个对应的值,那么这些值就构成了点值表达式。
我们同样可以认为这个点值表达式可以唯一确定出一个多项式。
证明如下:(转自Menci,鸣谢作者,链接详见左侧友情链接)
证明:假设命题不成立,存在两个不同的 n−1 次多项式 A(x)、B(x),满足对于任何 i∈[0, n−1],有 A(xi)=B(xi)。
令 C(x)=A(x)−B(x),则 C(x) 也是一个 n−1 次多项式。对于任何 i∈[0, n−1],有 C(xi)=0。
即 C(x) 有 n 个根,这与代数基本定理(一个 n−1 次多项式在复数域上有且仅有 n−1 个根)相矛盾,故 C(x) 并不是一个 n−1 次多项式,原命题成立,证毕。
插值:已知点值表达,求系数表达式
单位根:我们上文提及虚数可以在坐标系内表示。我们可以在坐标系内做半径为1的圆,作为单位元,如果把单位圆分成n分,那么最靠近x轴正半轴的一份的考上的边即为wn = w0,即为单位根,剩下的依次为w1,w2,w3……wn-1;
其中单位根的幅角2π/n ,由欧拉公式可以得出cos2k2n2π+isin2k2n2π=coskn2π+isinkn2π
我们在求解点值表达式时,通常带入单位根,举个例子
如果有n项,那么我们可以分别带入wn^0,wn^2,wn^(n-1),这样子便于计算,此结论是前人证明,在此不详细叙述。
同时我们随手得出几个小的结论。
没有什么比画图更能说明这个两个结论了。(结合上文提及的复数乘法)
折半定理:
ω2n2k=ωnk
ωnk+2n=−ωnk
好,我们开始步入正题,使用FFT进行多项式乘法在nlogn的时间内进行运算。
我们简述一下FFT的流程,先将这两个多项式转换为点值表达式,然后在线性将两个多项式每一位相乘,然后将得出的新的点值表达式转换会系数表达式输出即可。
我们闲来考虑第一部分,将系数表达式转化为点值表达式。
我们先明确一下,我们一下所指的所有多项式,最高次项均为2^k - 1。如果不足,请默认在高位补零。
通常将系数表达式转化为点值表达式有两种方法,递归与迭代,递归由于传参可能涉及到数组,所以通常效率稍微差些,而迭代版则不存在这个问题,但是递归更便于理解,所以我们从递归说起。
我们定义一个函数DFT(vector<复数>) vector内存着每一位的系数,可以将系数表达式转化为点值表达式
首先我们先明确边界,如果只有一个数,那么系数表达式就是点值表达式,直接返回即可。
如果没有到边界,我们进行如下操作。
先将每个元素分别按照下标的奇偶处理,分别递归操作。
这时候我们举个栗子来观察一下
a3 (w4) ^ 3 + a2(w4)^2 + a1(w4)^1 + a0
按奇偶分成两个递归
a3(w2) ^ 1 + a1 a2(w2) ^ 1 + a0
我们可以根据之前的折半定理
将递归出的左式转化为,a3(w4) ^ 2 + a1(w4),我么发现这个式子×w4即为上文的4项式子的奇数部分的值。
我们递归出的右式,通过折半定理,可以转化成 a2(w4)^2 + a0 即为4项式的偶数项的和。
我们上文讲的式对于带入一个值求出的点值表达式,而我们的DFT是返回带入n个单位根,每个答案分别存在vector中一位。
所以我们要宏观的再来考虑一下,我们用f(i)表示带入i,当前多项式得到的结果。
我们来宏观的考虑一下。
f(w0) ,f(w1),f(w2),f(w3) 对应的多项式 a3(wx)^3 + a2(wx) ^ 2 + a1(wx) ^ 1 + a0
f(w1),f(w0) 对应的多项式 a3(w x/2)^1 + a1 f(w1),f(w0) 对应的多项式 a2(w x/2)^1 + a0
根据我们上文的推理,我们可以得出
f(w0) a3(w0)^3 + a2(w0) ^ 2 + a1(w0) ^ 1 + a0
=w0(a3(w0) ^ 1 + a1) + a2(w0) ^ 1 + a0
= 左f(w0) * w0 + 右f(w0)
同理我们只需要将w1……wn-1同样处理,只是每次不是在 左×w0而是(w1……wn-1)即可。
这样子我认为已经很详细的写出了如何在递归中求出点值表达式了。
我们给出代码:
vector DFT(vector a)
{
if (a.size() == 1) return a;
vector a1,a0,y1,y0,ans;
for (int i = 0;i < a.size();i++)
{
if (i & 1)
a1.push_back(a[i]);else
a0.push_back(a[i]);
}
y0 = DFT(a0,pd);
y1 = DFT(a1,pd);
pot wn;
wn = pot(cos(2.0f * PI /(double) a.size()),sin(2.0f * PI / (double) a.size()));
pot w = pot(1.0,0.0);
ans.resize(a.size());
for (int i = 0;i < a.size() / 2;i++)
{
ans[i] = y0[i] + y1[i] * w;
ans[i + (a.size() >> 1)] = y0[i] - y1[i] * w;
w = w * wn;
}
return ans;
}
后来有些小伙伴私信我问上述代码的第二个for循环内的计算对称部分的为什么在y1前加了负号,这个问题之前忘记书写了,这里解释一下。
我们考虑ak((wn) ^ p)^k 对应的在右侧部分的为ak((wn) ^ (p + n/2))^k
然后我们上文有提及两个很基本的小性质,其中第二个可以把 (wn) ^ (p + n/2) 转换为-wn^p,而当我们位于右侧部分,也就是所谓的奇数部分,最外侧的^k不为偶数,算出的值为负,所以需要在右侧加上-号
有人问wn为什么那么算,就是在复数坐标轴上很简单的几何意义,可以自己画一下。
递归如果已经理解,那么迭代就非常容易理解了,在此给出代码,基本思路跟递归是相同的,只不过我们通过一个for循环来枚举长度而已,但注意此时我们发现迭代中缺少了递归中奇偶分类,但是非常幸运,我们是可以预先推算吃迭代的处理顺序,从而提前处理好奇偶数的位置关系 ,这里给出基于二分的nlogn处理方式,这里非常显然,不做任何讲解。然而有一种更加优美的写法,通过二进制的奇妙操作,在常数较短的时间内进行处理奇数偶数。
暴力nlogn写法:
void pre(int l, int r)
{
if (l < r)
{
int mid = (l + r) >> 1;
static pot dl[500010], dr[500010];
for (int i = l; i < r; i += 2)
{
dl[(i - l) >> 1] = dla[i];
dr[(i - l) >> 1] = dla[i + 1];
}
memcpy(dla + l, dl, (mid - l + 1) * sizeof(dla[0]));
memcpy(dla + mid + 1, dr, (r - mid) * sizeof(dla[0]));
for (int i = l; i < r; i += 2)
{
dl[(i - l) >> 1] = dlb[i];
dr[(i - l) >> 1] = dlb[i + 1];
}
memcpy(dlb + l, dl, (mid - l + 1) * sizeof(dlb[0]));
memcpy(dlb + mid + 1, dr, (r - mid) * sizeof(dlb[0]));
pre(l, mid);
pre(mid + 1, r);
}
}
我们可以用一个
000 001 010 011 100 101 110 111 0 1 2 3 4 5 6 7 0 2 4 6 - 1 3 5 7 0 4 - 2 6 - 1 5 - 3 7 0 - 4 - 2 - 6 - 1 - 5 - 3 - 7 000 100 010 110 001 101 011 111(本段演示及相关二进制转化代码均转自Menci,鸣谢作者,链接详见友情链接)
int k = 0; while ((1 << k) < n) k++; for (int i = 0; i < n; i++) { int t = 0; for (int j = 0; j < k; j++) if (i & (1 << j)) t |= (1 << (k - j - 1)); if (i < t) std::swap(a[i], a[t]); }
void DFT()
{
for (int mi = 2;mi <= n;mi <<= 1)
{
pot wn;
wn = pot(cos(PI * 2 / (double) mi),-sin(PI * 2 / (double) mi));
for (int j = 0;j < n;j += mi)
{
int midn = j + (mi >> 1);
pot w = pot(1.0f,0);
for (int k = j;k < midn;k++)
{
pot tp = dla[k];
dla[k] = dla[k] + w * dla[k + (mi >> 1)];
dla[k + (mi >> 1)] = tp - w * dla[k + (mi >> 1)];
tp = dlb[k];
dlb[k] = dlb[k] + w * dlb[k + (mi >> 1)];
dlb[k + (mi >> 1)] = tp - w * dlb[k + (mi >> 1)];
w = w * wn;
}
}
}
}
我们来考虑一下,如何将点值表达式转化为系数表达式。
我们把从系数表达式求成点值表达式的过程抽象为矩阵乘法
A矩阵 (wn的幂<0时)D矩阵 否则为V矩阵,即D是V的逆矩阵 F矩阵
以下为具体的矩阵推导过程,(鸣谢xys在此的帮助,其github详见友情链接)
F=V×A(显而易见)
E=D×V(由一些奇妙的定理可得,E为长度为n的单位矩阵,即对角线为n,其余区域为0)`
I=1/nE (I为单位长度是1的单位矩阵)
=1/nD × V
1/n D = V^(-1) 与F=V×A连理得
1/n DF = A
我们回头来看,这tmd不就是DFT的逆过程么,只需要在前面加一个1/n即可。
所以我们只需要在DFT内加一个小的改动,并将结果进行一个小处理即可。
具体改动是将
wn = pot(cos(2.0f * PI /(double) a.size()),sin(2.0f * PI / (double) a.size()));
在sin前面加上一个负号,并且在操作完全结束后,将ans数组/n即可。
至于为什么这么做,因为实际上,我们在操作中不会单独搞出一个D矩阵,而是继续将就着用V矩阵来节约代码量,所以我们需要将wn进行修改。而ans/n是因为我们
1/n DF = A 前有一个n/1,最后上传一下二合一的DFT与IDFT,传参为0或任意非零值。
迭代版本:
void DFT(int pd)
{
for (int mi = 2;mi <= n;mi <<= 1)
{
pot wn;
if (!pd)
wn = pot(cos(PI * 2 / (double) mi),sin(PI * 2 / (double) mi));else
wn = pot(cos(PI * 2 / (double) mi),-sin(PI * 2 / (double) mi));
for (int j = 0;j < n;j += mi)
{
int midn = j + (mi >> 1);
pot w = pot(1.0f,0);
for (int k = j;k < midn;k++)
{
pot tp = dla[k];
dla[k] = dla[k] + w * dla[k + (mi >> 1)];
dla[k + (mi >> 1)] = tp - w * dla[k + (mi >> 1)];
tp = dlb[k];
dlb[k] = dlb[k] + w * dlb[k + (mi >> 1)];
dlb[k + (mi >> 1)] = tp - w * dlb[k + (mi >> 1)];
w = w * wn;
}
}
}
}
vector DFT(vector a,int pd)
{//printf("DFT : %d\n", a.size());
if (a.size() == 1) return a;
vector a1,a0,y1,y0,ans;
for (int i = 0;i < a.size();i++)
{
if (i & 1)
a1.push_back(a[i]);else
a0.push_back(a[i]);
}
y0 = DFT(a0,pd);
y1 = DFT(a1,pd);
pot wn;
if (pd == 0) wn = pot(cos(2.0f * PI /(double) a.size()),-sin(2.0f * PI / (double) a.size()));else
wn = pot(cos(2.0f * PI /(double) a.size()),sin(2.0f * PI / (double) a.size()));
pot w = pot(1.0,0.0);
ans.resize(a.size());
for (int i = 0;i < a.size() / 2;i++)
{
ans[i] = y0[i] + y1[i] * w;
ans[i + (a.size() >> 1)] = y0[i] - y1[i] * w;
w = w * wn;
}
return ans;
}
int l = la + lb + 2;
int k = 0;
while (l > 0)
{
k++;
l >>= 1;
}
n = 1 << k;
第一次写技术含量这么高的博客,如果有什么不周到的,大家可以提出,或私信我,有人问上面出现的wn和w0的关系,我这默认wn== w0,酱紫。
看完全文的小伙伴辛苦了,有兴趣可以点点友链之类的。= ̄ω ̄=