自己实现用FFT加速多项式计算

/*
 * FFT之所以能加速DFT的计算,得益于n次单位复数根的几个性质:
 * 1.消去引理  w(d*n,d*k)=w(n,k);
 * 1.折半引理  w(n,k)^2=w(n/2,k);
 * 2.求和引理,即:sum{w(n,k)^j}=0 (0<=j<=n-1);

* 详见:算法导论第30章,p527~p535

* 不过我的这个递归实现运行速度太慢了,我测试了一下,运行时间是这种写法的3~4倍,所以在比赛的时候应尽量选择速度快的实现方式。

*/


#include<stdio.h>
#include<string.h>
#include<stdlib.h>
#include<math.h>
#include<vector>
#include<algorithm>

using namespace std;

#define N 150001

const double pi=acos(-1.0);

struct Complex{
    double real,img;
    Complex(double x=0,double y=0):real(x),img(y){}
};

Complex operator+(Complex A,Complex B){ return Complex(A.real+B.real,A.img+B.img); }
Complex operator-(Complex A,Complex B){ return Complex(A.real-B.real,A.img-B.img); }
Complex operator*(Complex A,Complex B){ return Complex(A.real*B.real-A.img*B.img,A.real*B.img+A.img*B.real); }

Complex *A[20][2];
Complex *Y[20][2];

void init(int n){ //预开空间
    int i;
    for(i=0;i<20&&n;i++){
        A[i][0]=new Complex[n];
        A[i][1]=new Complex[n];
        Y[i][0]=new Complex[n];
        Y[i][1]=new Complex[n];
        n>>=1;
    }
}
void FFT(int n,int flag,int dep,int t) //算法导论上面的伪代码c++实现
{
    int i;
    if(n==1){
        Y[dep][flag][0]=A[dep][flag][0];
        return ;
    }
    Complex W(1,0),O;
    if(t)O=Complex(cos(2*pi/n),sin(2*pi/n));  //主n次单位根
    else O=Complex(cos(-2*pi/n),sin(-2*pi/n));//逆主n次单位根,用于计算逆DFT
    for(i=0;i<n/2;i++) A[dep+1][0][i]=A[dep][flag][i*2];
    for(i=0;i<n/2;i++) A[dep+1][1][i]=A[dep][flag][i*2+1];
    FFT(n>>1,0,dep+1,t);  //递归计算Y0
    FFT(n>>1,1,dep+1,t);  //递归计算Y1
    for(i=0;i<n/2;i++){   //计算当前的DFT
        Y[dep][flag][i]=Y[dep+1][0][i]+W*Y[dep+1][1][i];
        Y[dep][flag][i+n/2]=Y[dep+1][0][i]-W*Y[dep+1][1][i];
        W=W*O;
    }
}
void cal(char *a,char *b,char *c){  //实现大整数a*b=c
    int i,n,m,L;
    n=strlen(a);m=strlen(b);L=max(n,m);
    //得到一个大于等于2倍L的2的幂次的长度,便于FFT的计算
    for(i=1;i<2*L;i*=2); L=i;
    for(i=0;i<n/2;i++) swap(a[i],a[n-1-i]);
    for(i=0;i<m/2;i++) swap(b[i],b[m-1-i]);
    //初始化多项式系数
    for(i=0;i<L;i++){
        if(i<n) A[0][0][i].real=a[i]-'0';
        else A[0][0][i].real=0;
        A[0][0][i].img=0;
    }
    for(i=0;i<L;i++){
        if(i<m) A[0][1][i].real=b[i]-'0';
        else A[0][1][i].real=0;
        A[0][1][i].img=0;
    }
    //用FFT加速计算DFT把系数式转为点值式
    FFT(L,0,0,1);FFT(L,1,0,1);
    //计算点值式
    for(i=0;i<L;i++) Y[0][0][i]=Y[0][0][i]*Y[0][1][i];
    //利用FFT计算逆DFT把点值式转化为系数式
    for(i=0;i<L;i++) A[0][0][i]=Y[0][0][i];
    FFT(L,0,0,0);
    //最后处理一下多项式系数转化为大整数时的进位
    int carry=0;n=0;
    for(i=0;i<L;i++){
        carry+=Y[0][0][i].real/L+0.5;
        c[i]='0'+carry%10;
        if(carry) n=i;carry/=10;
    }
    if(carry) c[++n]=carry+'0';
    c[++n]=0;for(i=0;i<n/2;i++) swap(c[i],c[n-1-i]);
}

char a[51011],b[51011],c[120001];
int main(){
    init(N);
    while(scanf("%s%s",a,b)!=EOF){
        cal(a,b,c);
        printf("%s\n",c);
    }
	return 0;
}


你可能感兴趣的:(自己实现用FFT加速多项式计算)