FFT、NTT模板

快速傅里叶

    • 基础知识
    • FFT进行多项式乘法的步骤
    • P3803 【模板】多项式乘法(FFT)
      • FFT写法
      • NTT写法

基础知识

离散傅里叶(DFT)和逆变换(IDFT)时间复杂度都是 O ( n 2 ) O(n^2) O(n2)
快速傅里叶(FFT)时间复杂度为 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n)

  • 推导过程:
    A ( x ) = a 0 + a 1 x + a 2 x 2 + ⋯ + a n x n A(x)=a_0+a_1x+a_2x^2+\dots+a_{n}x^n A(x)=a0+a1x+a2x2++anxn,假设n为偶数,根据奇偶性分为:
    A ( x ) = ( a 0 + a 2 x 2 + ⋯ + a n x n ) + ( a 1 x + a 3 x 3 + ⋯ + a n − 1 x n − 1 ) A(x)=(a_0+a_2x^2+\dots+a_nx^n)+(a_1x+a_3x^3+\dots+a_{n-1}x^{n-1}) A(x)=(a0+a2x2++anxn)+(a1x+a3x3++an1xn1)

    设多项式 A 1 ( x ) = a 0 + a 2 x + ⋯ + a n x n 2 A_1(x)=a_0+a_2x+\dots+a_nx^{\frac n2} A1(x)=a0+a2x++anx2n A 2 ( x ) = a 1 + a 3 x + ⋯ + a n − 1 x n 2 − 1 A_2(x)=a_1+a_3x+\dots+a_{n-1}x^{\frac {n}2-1} A2(x)=a1+a3x++an1x2n1
    因此可以得到:
    A ( x ) = A 1 ( x 2 ) + x A 2 ( x 2 ) A(x)=A_1(x^2)+xA_2(x^2) A(x)=A1(x2)+xA2(x2)
    k < n 2 k<\frac n2 k<2n,把 x = w n k x=w_n^k x=wnk代入:
    A ( w n k ) = A 1 ( w n 2 k ) + w n k A 2 ( w n 2 k ) A(w_n^k)=A_1(w_n^{2k})+w_n^kA_2(w_n^{2k}) A(wnk)=A1(wn2k)+wnkA2(wn2k)
    A ( w n k ) = A 1 ( w n 2 k ) + w n k A 2 ( w n 2 k ) A(w_n^k)=A_1(w_{\frac n2}^{k})+w_n^kA_2(w_{\frac n2}^{k}) A(wnk)=A1(w2nk)+wnkA2(w2nk)
    x = w n k + n 2 x=w_n^{k+\frac n2} x=wnk+2n代入可得:
    A ( w n k + n 2 ) = A 1 ( w n 2 k + n ) + w n k + n 2 A 2 ( w n 2 k + n ) A(w_n^{k+\frac n2})=A_1(w_n^{2k+n}) +w_n^{k+\frac n2}A_2(w_n^{2k+n}) A(wnk+2n)=A1(wn2k+n)+wnk+2nA2(wn2k+n)
    A ( w n k + n 2 ) = A 1 ( w n 2 k ) − w n k A 2 ( w n 2 k ) A(w_n^{k+\frac n2})=A_1(w_{\frac n2}^{k}) -w_n^{k}A_2(w_{\frac n2}^{k}) A(wnk+2n)=A1(w2nk)wnkA2(w2nk)
    观察化简后的两式,可知只要知道了 A 1 ( w n 2 k ) A_1(w_{\frac n2}^{k}) A1(w2nk) A 2 ( w n 2 k ) A_2(w_{\frac n2}^{k}) A2(w2nk),就可以求得 A ( w n k + n 2 ) A(w_n^{k+\frac n2}) A(wnk+2n) A ( w n k ) A(w_n^k) A(wnk),也就是通过下层的两个值,求得当前层的两个值

  • 理解:假设多项式 A ( x ) = a 0 + a 1 x 1 + ⋯ + a n − 1 x n − 1 A(x)=a_0+a_1x^1+\dots+a_{n-1}x^{n-1} A(x)=a0+a1x1++an1xn1 为 n-1 次多项式,(n 为 2的倍数,保证了能将圆周等分),将 w n k ( k = 0 , … , n − 1 ) w_n^k (k=0,\dots,n-1) wnk(k=0,,n1)代入后得到离散傅里叶变换 ( b 0 , b 1 , … , b n − 1 ) (b_0,b_1,\dots,b_{n-1}) (b0,b1,,bn1),设为 B ( x ) B(x) B(x) 的系数,即 B ( x ) = b 0 + b 1 x 1 + ⋯ + b n − 1 x n − 1 B(x)=b_0+b_1x^1+\dots+b_{n-1}x^{n-1} B(x)=b0+b1x1++bn1xn1 ,然后再将 w n k ( k = 0 , − 1 , − 2 , … , − ( n − 1 ) ) w_n^k (k=0,-1,-2,\dots,-(n-1)) wnk(k=0,1,2,,(n1)) 代入 B ( x ) B(x) B(x)。得到 C ( x ) = c 0 + c 1 x 1 + ⋯ + c n − 1 x n − 1 C(x)=c_0+c_1x^1+\dots+c_{n-1}x^{n-1} C(x)=c0+c1x1++cn1xn1,通过推导可以得到:
    c i = n a i c_i=na_i ci=nai
    这里只做了两步操作,第一步将系数转成了点值,第二步将点值转成了系数的n倍

推导过程參考博客

FFT进行多项式乘法的步骤

1、对两个多项式补0
2、用FFT计算两个多项式A、B的点值表示法
3、得到乘积多项式C的点值表示法
4、用FFT通过IDFT计算多项式C的系数表示

P3803 【模板】多项式乘法(FFT)

链接:https://www.luogu.com.cn/problem/P3803

题意
在这里插入图片描述

FFT写法

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=1e6+10,mod=1e9+7;
const double PI=acos(-1.0);

struct Complex
{
    double x,y;
    Complex(double x1=0.0,double y1=0.0)
    {
        x=x1,y=y1;
    }
};
Complex operator+(Complex a,Complex b)
{
    return Complex(a.x+b.x,a.y+b.y);
}
Complex operator-(Complex a,Complex b)
{
    return Complex(a.x-b.x,a.y-b.y);
}
Complex operator*(Complex a,Complex b)
{
    return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}

namespace FFT
{
    int total,digit,rev[maxn<<2];
    Complex a[maxn<<2],b[maxn<<2];
    void init(int len)//需要 n+m+1 个位置来表示这个 n+m 次多项式 
    {
        total=1,digit=0;
        while(total<=len)
            total<<=1,digit++;
        for(int i=0;i<total;++i)
            rev[i]=(rev[i>>1]>>1)|(i&1)<<(digit-1);
    }
    void fft(Complex *A,int f)
    {
        for(int i=0;i<total;++i)
            if(i<rev[i]) swap(A[i],A[rev[i]]);
        for(int mid=1;mid<total;mid<<=1)
        {
            Complex W1= Complex(cos(PI/mid),f*sin(PI/mid));
            int len=mid*2;
            for(int p=0;p<total;p+=len)
            {
                Complex Wk= Complex(1,0);
                for(int k=0;k<mid;++k)
                {
                    Complex x=A[p+k],y=Wk*A[p+k+mid];
                    A[p+k]=x+y;
                    A[p+k+mid]=x-y;
                    Wk=Wk*W1;
                }
            }
        }
        if(f==-1)
        {
            for(int i=0;i<total;++i)
                A[i].x=(int)(A[i].x/total+0.5);
        }
    }
    void calc()
    {
        fft(a,1),fft(b,1);
        for(int i=0;i<total;++i)
            a[i]=a[i]*b[i];
        fft(a,-1);
    }
};
using namespace FFT;

int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;++i)
        scanf("%lf",&a[i].x);
    for(int i=0;i<=m;++i)
        scanf("%lf",&b[i].x);
    init(n+m);
    calc();
    for(int i=0;i<=n+m;++i)
        printf("%.0lf%c",a[i].x,i==n+m?'\n':' ');
    return 0;
}

NTT写法

  • 与 FFT 的区别,FFT 是将一个圆周 ( 2 π ) (2\pi) (2π) 等分成 n 份,而这里是将 P-1 等分成 n 份
  • FFT每一步的大小: W 1 = c o s ( 2 π n ) + s i n ( 2 π n ) W_1=cos(\frac {2\pi}n )+sin (\frac{2\pi}n) W1=cos(n2π)+sin(n2π)
  • NTT每一步的大小: W 1 ≡ g p − 1 n m o d   p W_1\equiv g^{\frac {p-1}n} mod\ p W1gnp1mod p
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=1e6+10;

int qpow(int b,int n,int mod)
{
    int res=1;
    while(n>0)
    {
        if(n&1) res=1ll*res*b%mod;
        b=1ll*b*b%mod;
        n>>=1;
    }
    return res;
}

namespace NTT
{
	const int G=3,P=998244353,GI=332748118;
	int total,digit,rev[maxn<<2];
	int a[maxn<<2],b[maxn<<2];
	void init(int len)
	{
		total=1,digit=0;
		while(total<=len)
			total<<=1,digit++;
		for(int i=0;i<total;++i)
		{
			rev[i]=(rev[i>>1]>>1)|(i&1)<<(digit-1);
			a[i]=0,b[i]=0;
		}	
	}
	void ntt(int *A,int f)
	{
		for(int i=0;i<total;++i)
			if(i<rev[i]) swap(A[i],A[rev[i]]);
		for(int mid=1;mid<total;mid<<=1)
		{
			int W1,len=mid*2;;
			if(f==1) W1=qpow(G,(P-1)/len,P);
			else W1=qpow(GI,(P-1)/len,P);
			for(int p=0;p<total;p+=len)
			{
				int Wk=1;
				for(int k=0;k<mid;++k)
				{
					int x=A[p+k],y=1ll*Wk*A[p+k+mid]%P;
					A[p+k]=(1ll*x+y)%P;
					A[p+k+mid]=(1ll*x-y+P)%P; 	
					Wk=1ll*Wk*W1%P;	
				}		
			}	
		}		
		if(f==-1)
		{
			int invp=qpow(total,P-2,P);
			for(int i=0;i<total;++i)
				a[i]=1ll*a[i]*invp%P; 
		}
	}	
	void calc()
	{
		ntt(a,1),ntt(b,1);
		for(int i=0;i<total;++i)
			a[i]=1ll*a[i]*b[i]%P;
		ntt(a,-1);
	}
};
using namespace NTT;

int main()
{	
	int n,m;	
	scanf("%d%d",&n,&m);
	init(n+m);
	for(int i=0;i<=n;++i)
		scanf("%d",&a[i]);
	for(int i=0;i<=m;++i)
		scanf("%d",&b[i]);
	calc();
	for(int i=0;i<=n+m;++i)
		printf("%d%c",a[i],i==n+m?'\n':' ');
	return 0;
}

你可能感兴趣的:(多项式,数学,FFT)