快速傅里叶变换(FFT),常用于解答多项式乘法相关内容。
背景故事:
在我们平时计算多项式乘法的时候,我们把第一个多项式的每一项都和第二个多项式的每一项相乘,复杂度为O(n ^ 2),此时我们所使用的表示法就是系数表示法。
现在我们可以有一种比较强大的方式:
点值表示法:
众所周知,假设f(x)的最高次数为 n−1 ,即次数界为 n ,那么我们只要知道了n个不相同的x及f(x)值,就能确定出f(x)的多项式。
有一种算法叫做秦九韶算法,它能得出的结论是:
对于一个n次多项式,至多做n次乘法和n次加法。
所以我们知道了n个点的x值之后,我们在O(n^2)的时间内就能计算出所有的y值,然而如果通过快速傅里叶变换的方法,可以在O(nlogn)的时间内求出所有的y值。
引入新定义:
求值:通过多项式的系数表示法求其点值表示法。
插值:通过多项式的点值表示法求其系数表示法。
显然上面两个定义是互逆的关系,FFT就是用来解决求值的过程的。
引入新定义:
n次单位复根:在复平面内,n次单位复根能把整个平面分成n块。它的严格定义是:满足 ωn=1 的复数 ω 值,它一共有n个,分别为 ωkn(k=0...n−1) ,其数值为 e2πik/n。
由复数幂的定义,可知:
(ω1n)k=ωkn
它有很多性质:
1.相消引理:
接下来我们分析一下如何得到最终的多项式吧。
1.求A的n个单位根的点值,求B的n个单位根的点值
2.点值相乘,得到C的点值。
3.计算C的多项式。
我们先考虑第一步。
我们希望计算多项式:
也就是说,点值的结果就是DFT,那么我们只需要快速计算DFT的值即可。
如果按照正常的算法,时间复杂度显然是 O(n2) 的,所以有快速傅里叶变换(FFT),它采取的是分治的思想。
我们考虑原来的多项式A(x),定义两个新的次数界为n / 2的多项式:
把点值相乘。
显而易见的一点是,我们如果有两个多项式A(x),B(x),它们相乘得到多项式C(x),则:
接下来我们要做的就是插值了。[忘了定义的请自觉往回翻]
因为我不知道拉格朗日插值公式是什么鬼(划掉),所以我决定还是不讲了。
//别打我别打我……我讲…………
嗯我们考虑把DFT写成矩阵的形式:
bzoj2179,FFT模版,据说我这样不加注释会运行的更快。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define Rep(i,n) for(int i = 1;i <= n;i ++)
#define Rep_0(i,n) for(int i = 0;i < n;i ++)
#define PI M_PI
using namespace std;
struct Virt
{
double r, i;
Virt(double r = 0.0,double i = 0.0)
{
this->r = r;
this->i = i;
}
Virt operator + (const Virt &x)
{
return Virt(r + x.r, i + x.i);
}
Virt operator - (const Virt &x)
{
return Virt(r - x.r, i - x.i);
}
Virt operator * (const Virt &x)
{
return Virt(r * x.r - i * x.i, i * x.r + r * x.i);
}
};
const int M = 131072;
Virt a[M],b[M];
int n,m,BL,rev[M],c[M];
char ch[M >> 1];
void Rader()
{
for(n = 1;n <= m; n <<= 1,BL ++);
for (int i = 1; i < n; i ++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (BL - 1));
}
void FFT(Virt *a,int ty = 1)
{
Rep_0(i,n)if(i < rev[i])swap(a[i],a[rev[i]]);
for(int i = 1;i < n;i <<= 1)
{
Virt wn(cos(PI / i),ty * sin(PI / i));
for(int j = 0;j < n;j += (i << 1))
{
Virt w(1,0);
for(int k = 0;k < i;k ++,w = w * wn)
{
Virt u = a[j + k],v = w * a[j + k + i];
a[j + k] = u + v,a[j + k + i] = u - v;
}
}
}
if(ty == -1)Rep_0(i,n)a[i] = a[i].r / n;
}
int main ()
{
scanf("%d",&n);
scanf("%s",ch);
Rep_0(i,n)a[i] = ch[n - i - 1] - '0';
scanf("%s",ch);
Rep_0(i,n)b[i] = ch[n - i - 1] - '0';
n --;
m = n << 1;
Rader();
FFT(a),FFT(b);
for(int i = 0;i <= n;i ++)a[i] = a[i] * b[i];
FFT(a,-1);
for(int i = 0;i <= m;i ++)c[i] = (int)(a[i].r + 0.1);
for(int i = 0;i <= m;i ++)
{
if(c[i] >= 10)
{
c[i + 1] += c[i] / 10,c[i] %= 10;
if(i == m)m ++;
}
}
for(int i = m;i >= 0;i --)printf("%d",c[i]);
return 0;
}