快速傅里叶变换

自己写的课件公式太多不好弄上来,还是算了。
贴两个模板。一个FFT一个NTT,都是UOJ#34的。

#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;
const double pi=acos(-1.0);
struct complex{//建议手封一个复数,比系统自带快400ms以上
    double r,i;
    complex(){}
    complex(double _r,double _i){r=_r;i=_i;}
}a[262145],b[262145];//注意a,b,r这3个数组的大小
int n,m,L,r[262145];
inline complex operator + (const complex &x,const complex &y)
{
    return complex(x.r+y.r,x.i+y.i);
}
inline complex operator - (const complex &x,const complex &y)
{
    return complex(x.r-y.r,x.i-y.i);
}
inline complex operator * (const complex &x,const complex &y)
{
    return complex(x.r*y.r-x.i*y.i,x.r*y.i+x.i*y.r);
}
void fft(complex *a,int f)
{
    for(int i=0;i<n;++i)if(i<r[i])swap(a[i],a[r[i]]);
    for(int i=1;i<n;i<<=1)
    {
        complex wn(cos(pi/i),f*sin(pi/i));
        for(int j=0;j<n;j+=i<<1)
        {
            complex w(1,0);
            for(int k=0;k<i;++k,w=w*wn)
            {
                complex x=a[j+k],y=w*a[j+k+i];//注意y要乘w这个很容易忘
                a[j+k]=x+y,a[j+k+i]=x-y;
            }
        }
    }
    if(f==-1)for(int i=0;i<n;++i)a[i].r/=n;//注意要除以n这个很容易忘
}
int main()
{
    scanf("%d%d",&n,&m);//为了方便记忆的话,两个数组的起点是0,终点是n or m,然后主函数里面for循环全部写<=不会错
    for(int i=0;i<=n;++i)scanf("%lf",&a[i].r);
    for(int i=0;i<=m;++i)scanf("%lf",&b[i].r);
    m=n+m;
    for(n=1;n<=m;n<<=1)++L;
    for(int i=0;i<=n;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
    fft(a,1);fft(b,1);
    for(int i=0;i<=n;++i)a[i]=a[i]*b[i];
    fft(a,-1);
    for(int i=0;i<=m;++i)printf("%d%c",(int)(a[i].r+0.5),i==m?'\n':' ');
}
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
const ll mod=1005060097,G=5;
ll a[262145],b[262145];
int n,m;
ll pow(ll x,ll y)
{
    ll ret=1;
    while(y)
    {
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return ret;
}
void ntt(ll*a,int n,int f)
{
    for(int i=1,j=0;i<n-1;++i)
    {
        for(int d=n;j^=d>>=1,~j&d;);
        if(i<j)swap(a[i],a[j]);
    }
    for(int i=1;i<n;i<<=1)
    {
        ll wn=pow(G,(mod-1)/(i<<1));
        for(int j=0;j<n;j+=i<<1)
        {
            ll w=1;
            for(int k=0;k<i;++k,w=w*wn%mod)
            {
                ll x=a[j+k],y=w*a[j+k+i]%mod;
                a[j+k]=(x+y)%mod;
                a[j+k+i]=(x-y+mod)%mod;
            }
        }
    }
    if(f==-1)
    {
        reverse(a+1,a+n);
        ll inv=pow(n,mod-2);
        for(int i=0;i<n;++i)a[i]=a[i]*inv%mod;
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;++i)scanf("%lld",&a[i]);
    for(int i=0;i<=m;++i)scanf("%lld",&b[i]);
    for(m=n+m,n=1;n<=m;n<<=1);
    ntt(a,n,1);ntt(b,n,1);
    for(int i=0;i<=n;++i)a[i]=a[i]*b[i]%mod;
    ntt(a,n,-1);
    for(int i=0;i<=m;++i)printf("%lld ",a[i]);
}

分治fft模板,不过话说回来分治fft可以被多项式求逆替代

#include<cstdio>
#include<iostream>
#define ll long long
using namespace std;
const int mod=786433;
int n,m,c,cas,ans,fac[500010],f[500010],A[462145],B[462145];
ll p[20],g[20],ng[20];
ll pow(ll x,ll y)
{
    ll res=1;
    while(y)
    {
        if(y&1)res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}
void init()
{
    fac[0]=p[0]=g[0]=ng[0]=1;
    for(int i=1;i<=18;++i)
    {
        p[i]=p[i-1]<<1;
        g[i]=pow(10,(mod-1)/p[i]);
        ng[i]=pow(g[i],mod-2);
    }
    for(int i=1;i<=100000;++i)f[i]=fac[i]=1ll*i*fac[i-1]%mod;
}
void ntt(int *a,int n,int f)
{
    for(int i=1,j=0;i<n-1;++i)
    {
        for(int s=n;j^=s>>=1,~j&s;);
        if(i<j)swap(a[i],a[j]);
    }
    for(int len=n,now=1,d=1;len>1;len>>=1,++now,d<<=1)
    {
        int wn=f?g[now]:ng[now];
        for(int i=0;i<n;i+=d<<1)
            for(int j=0,w=1;j<d;++j,w=1ll*w*wn%mod)
            {
                int tmp=1ll*a[i+j+d]*w%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
    }
}
void solve(int l,int r)
{
    if(l==r)return;
    int mid=(l+r)>>1;
    solve(l,mid);
    int k=1,len=r-l+1;
    for(;k<len;k<<=1);
    for(int i=0;i<k;++i)A[i]=i<=(mid-l)?f[i+l]:0;
    for(int i=0;i<k;++i)B[i]=fac[i+1];
    ntt(A,k,1);ntt(B,k,1);
    for(int i=0;i<k;++i)A[i]=1ll*A[i]*B[i]%mod;
    ntt(A,k,0);
    ll inv=pow(k,mod-2);
    for(int i=0;i<k;++i)A[i]=1ll*A[i]*inv%mod;
    for(int i=mid+1;i<=r;++i)f[i]=(f[i]-A[i-l-1]+mod)%mod;
    solve(mid+1,r);
}
int main()
{
    init();
    solve(1,100000);
    scanf("%d",&cas);
    while(cas--)
    {
        ans=1;
        scanf("%d%d",&n,&m);
        for(int i=1;i<=m;++i)
        {
            scanf("%d",&c);
            int L=n,R=0;
            for(int j=1,x;j<=c;++j)
            {
                scanf("%d",&x);
                L=min(L,x);
                R=max(R,x);
            }
            if(R-L+1!=c)ans=0;
            else ans=1ll*ans*f[c]%mod;
        }
        printf("%d\n",ans);
    }
}

所以来个多项式求逆模板吧:

#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 262144
#define ll long long
using namespace std;
const ll mod=1005060097;
ll a[N],b[N];
int n,m,k,len;
ll pow(ll x,ll y)
{
    ll ret=1;
    while(y)
    {
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return ret;
}
void NTT(ll *a,int len,int type)
{
    for(int i=0,t=0,j=0;i<len;++i)
    {
        if(i>t)swap(a[i],a[t]);
        for(j=(len>>1);(t^=j)<j;j>>=1);
    }
    for(int h=2;h<=len;h<<=1)
    {
        ll wn=pow(5,(mod-1)/h);
        for(int i=0;i<len;i+=h)
        {
            ll w=1;
            for(int j=0;j<(h>>1);++j,w=w*wn%mod)
            {
                ll temp=w*a[i+j+(h>>1)]%mod;
                a[i+j+(h>>1)]=(a[i+j]-temp+mod)%mod;
                a[i+j]=(a[i+j]+temp)%mod;
            }
        }
    }
    if(type==-1)
    {
        for(int i=1;i<(len>>1);++i)swap(a[i],a[len-i]);
        ll inv=pow(len,mod-2);
        for(int i=0;i<len;++i)a[i]=a[i]*inv%mod;
    }
}
void inv(ll *a,int len)
{
    if(len==1)
    {
        b[0]=pow(a[0],mod-2);
        return;
    }
    inv(a,len>>1);
    static ll temp[N];
    memcpy(temp,a,sizeof(ll)*(len>>1));
    NTT(temp,len,1);
    NTT(b,len,1);
    for(int i=0;i<len;++i)b[i]=b[i]*(2-temp[i]*b[i]%mod+mod)%mod;
    NTT(b,len,-1);
    memset(b+(len>>1),0,sizeof(ll)*(len>>1));
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;++i)
    {
        scanf("%d",&k);
        if(k<=n)--a[k];
        a[k]=(mod+a[k])%mod;
    }
    a[0]=1;
    for(len=1;len<=n+n+3;len<<=1);
    inv(a,len);
    printf("%lld\n",b[n]);
}

下面发个最终版本的fft和ntt,优点在于接口简洁,外部几乎没有任何预处理操作(当然你要把做fft和ntt的长度搞成2的幂,ntt的时候质数和原根要找好)。
fft

struct complex{
    double r,i;
    complex(double _r=0,double _i=0){r=_r,i=_i;}
    complex operator+(complex rhs){return complex(r+rhs.r,i+rhs.i);}
    complex operator-(complex rhs){return complex(r-rhs.r,i-rhs.i);}
    complex operator*(complex rhs){return complex(r*rhs.r-i*rhs.i,i*rhs.r+r*rhs.i);}
}A[N],B[N],C[N];
const double pi=acos(-1.0);
void fft(complex*a,int n,int f)
{
    for(int i=1,j=0;i<n-1;++i)
    {
        for(int d=n;j^=d>>=1,~j&d;);
        if(i<j)swap(a[i],a[j]);
    }
    for(int i=1;i<n;i<<=1)
    {
        complex wn(cos(pi/i),sin(pi/i)*f);
        for(int j=0;j<n;j+=i<<1)
        {
            complex w(1,0);
            for(int k=0;k<i;++k,w=w*wn)
            {
                complex x=a[j+k],y=w*a[j+k+i];
                a[j+k]=x+y;
                a[j+k+i]=x-y;
            }
        }
    }
    if(f==-1)for(int i=0;i<n;++i)a[i].r/=n;
}

另一个版本是ifft处理方式类似ntt的fft

struct complex{
    double r,i;
    complex(double _r=0,double _i=0){r=_r,i=_i;}
    complex operator+(complex rhs){return complex(r+rhs.r,i+rhs.i);}
    complex operator-(complex rhs){return complex(r-rhs.r,i-rhs.i);}
    complex operator*(complex rhs){return complex(r*rhs.r-i*rhs.i,i*rhs.r+r*rhs.i);}
}A[N],B[N],C[N];
const double pi=acos(-1.0);
void fft(complex*a,int n,int f)
{
    for(int i=1,j=0;i<n-1;++i)
    {
        for(int d=n;j^=d>>=1,~j&d;);
        if(i<j)swap(a[i],a[j]);
    }
    for(int i=1;i<n;i<<=1)
    {
        complex wn(cos(pi/i),sin(pi/i));
        for(int j=0;j<n;j+=i<<1)
        {
            complex w(1,0);
            for(int k=0;k<i;++k,w=w*wn)
            {
                complex x=a[j+k],y=w*a[j+k+i];
                a[j+k]=x+y;
                a[j+k+i]=x-y;
            }
        }
    }
    if(f==-1)
    {
        reverse(a+1,a+n);
        for(int i=0;i<n;++i)a[i].r/=n;
    }
}

ntt

const ll mod=1005060097,G=5;
ll pow(ll x,ll y)
{
    ll ret=1;
    while(y)
    {
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return ret;
}
void ntt(ll*a,int n,int f)
{
    for(int i=1,j=0;i<n-1;++i)
    {
        for(int d=n;j^=d>>=1,~j&d;);
        if(i<j)swap(a[i],a[j]);
    }
    for(int i=1;i<n;i<<=1)
    {
        ll wn=pow(G,(mod-1)/(i<<1));
        for(int j=0;j<n;j+=i<<1)
        {
            ll w=1;
            for(int k=0;k<i;++k,w=w*wn%mod)
            {
                ll x=a[j+k],y=w*a[j+k+i]%mod;
                a[j+k]=(x+y)%mod;
                a[j+k+i]=(x-y+mod)%mod;
            }
        }
    }
    if(f==-1)
    {
        reverse(a+1,a+n);
        ll inv=pow(n,mod-2);
        for(int i=0;i<n;++i)a[i]=a[i]*inv%mod;
    }
}

再来一发ntt mod anynumber
比如我们要模1e9+7,这个显然直接上ntt不能做。
所以我们可以嘿嘿嘿
找3个1e9左右的费马质数,比如998244353,1005060097,950009857,然后对每个质数都做一次ntt,用CRT合并解。
注意到CRT的过程要把模数乘起来,这样long long不是很资瓷,于是需要手写一个高精度或者int128.
下面放一个HDU 5519
指数生成函数+ntt mod anynumber
这个题我学会了一个东西就是在写ntt的时候,遇到多个多项式相乘,我们可以把这n个多项式全部ntt,然后相乘,最后逆变换。但是不知道为什么ntt mod anynumber不行,必须一个个地乘,猜测是CRT这里有问题。
所以这里还有一个思想,就是启发式合并。对于n个多项式,我们开一个集合维护它们的长度。每次选取最短的两个进行ntt,然后将得到的新多项式塞集合。这样显然比较快。
对于本题,启发式合并就不需要了了,但是也不应该用一个数组保存解,一直乘下去,这样明显不优。应该两个两个地乘,再把每一对的乘积拿来乘。

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const ll M=131073ll,MOD=1000000007ll,P[3]={998244353ll,1005060097ll,950009857ll},G[3]={3ll,5ll,7ll},Inv[3]={644348675ll,675933219ll,647895261ll};
int n,m,cas,a0,a1,a2,a3,a4;
ll inv[M],fac[M];
inline ll pow(ll x,ll y,ll P)
{
    ll re=1;
    while(y)
    {
        if(y&1)re=re*x%P;
        x=x*x%P;
        y>>=1;
    }
    return re;
}
struct Int_128{
    unsigned long long a,b;
    Int_128(ll x){a=0,b=x;}
    friend bool operator < (Int_128 x,Int_128 y)
    {
        return x.a<y.a||x.a==y.a&&x.b<y.b;
    }
    friend Int_128 operator + (Int_128 x,Int_128 y)
    {
        Int_128 re(0);
        re.a=x.a+y.a+(x.b+y.b<x.b);
        re.b=x.b+y.b;
        return re;
    }
    friend Int_128 operator - (Int_128 x,Int_128 y)
    {
        y.a=~y.a;y.b=~y.b;
        return x+y+1;
    }
    void Div2()
    {
        b>>=1;b|=(a&1ll)<<63;a>>=1;
    }
    friend Int_128 operator * (Int_128 x,Int_128 y)
    {
        Int_128 re=0;
        while(y.a||y.b)
        {
            if(y.b&1)re=re+x;
            x=x+x;y.Div2();
        }
        return re;
    }
    friend Int_128 operator % (Int_128 x,Int_128 y)
    {
        Int_128 temp=y;
        int cnt=0;
        while(temp<x)temp=temp+temp,++cnt;
        for(;cnt>=0;cnt--)
        {
            if(temp<x)x=x-temp;
            temp.Div2();
        }
        return x;
    }
};
void ntt(ll*a,int n,int f,int P,int G)
{
    for(int i=1,j=0;i<n-1;++i)
    {
        for(int d=n;j^=d>>=1,~j&d;);
        if(i<j)swap(a[i],a[j]);
    }
    for(int i=1;i<n;i<<=1)
    {
        ll wn=pow(G,(P-1)/(i<<1),P);
        for(int j=0;j<n;j+=i<<1)
        {
            ll w=1;
            for(int k=0;k<i;++k,w=w*wn%P)
            {
                ll x=a[j+k],y=w*a[j+k+i]%P;
                a[j+k]=(x+y)%P;
                a[j+k+i]=(x-y+P)%P;
            }
        }
    }
    if(f==-1)
    {
        for(int i=1;i<(n>>1);++i)swap(a[i],a[n-i]);
        ll inv=pow(n,P-2,P);
        for(int i=0;i<n;++i)a[i]=a[i]*inv%P;
    }
}
inline void Polynomial_Multiplication(ll a[],ll b[],ll c[],int n)
{
    static ll A[3][M],B[3][M];
    for(int j=0;j<3;++j)
    {
        for(int i=0;i<n;++i)
        {
            A[j][i]=a[i]%P[j];
            B[j][i]=b[i]%P[j];
        }
        ntt(A[j],n,1,P[j],G[j]);
        ntt(B[j],n,1,P[j],G[j]);
        for(int i=0;i<n;i++)A[j][i]=A[j][i]*B[j][i]%P[j];
        ntt(A[j],n,-1,P[j],G[j]);
    }
    Int_128 _MOD=Int_128(P[0]*P[1])*P[2];
    for(int i=0;i<n;++i)
    {
        Int_128 temp=
        Int_128(P[1]*P[2])*Int_128(Inv[0]*A[0][i])+
        Int_128(P[0]*P[2])*Int_128(Inv[1]*A[1][i])+
        Int_128(P[0]*P[1])*Int_128(Inv[2]*A[2][i]);
        c[i]=(temp%_MOD%MOD).b;
    }
}
int main()
{
    scanf("%d",&cas);
    inv[0]=inv[1]=fac[0]=fac[1]=1;
    for(int i=2;i<=15000;++i)inv[i]=(MOD-MOD/i)*inv[MOD%i]%MOD,fac[i]=fac[i-1]*i%MOD;
    for(int i=2;i<=15000;++i)inv[i]=inv[i]*inv[i-1]%MOD;
    static ll A[M],B[M],C[M],D[M],E[M],F[M];
    for(int t=1;t<=cas;++t)
    {
        scanf("%d%d%d%d%d%d",&n,&a0,&a1,&a2,&a3,&a4);
        for(int i=0;i<=min(n,a0);++i)A[i]=inv[i];
        for(int i=min(n,a0)+1;i<=m;++i)A[i]=0;
        for(int i=0;i<=min(n,a1);++i)B[i]=inv[i];
        for(int i=min(n,a1)+1;i<=m;++i)B[i]=0;
        for(int i=0;i<=min(n,a2);++i)C[i]=inv[i];
        for(int i=min(n,a2)+1;i<=m;++i)C[i]=0;
        for(int i=0;i<=min(n,a3);++i)D[i]=inv[i];
        for(int i=min(n,a3)+1;i<=m;++i)D[i]=0;
        for(int i=0;i<=min(n,a4);++i)E[i]=inv[i];
        for(int i=min(n,a4)+1;i<=m;++i)E[i]=0;
        /*运用了启发式合并的思想,跑了5.1s*/
        for(m=1;m<min(a1,n)+min(a2,n);m<<=1);
        Polynomial_Multiplication(B,C,B,m);
        for(m=1;m<min(a3,n)+min(a4,n);m<<=1);
        Polynomial_Multiplication(D,E,D,m);
        for(m=1;m<min(a1,n)+min(a2,n)+min(a3,n)+min(a4,n);m<<=1);
        Polynomial_Multiplication(B,D,B,m);

    /* 如果不用启发式合并,直接B数组乘到底,5.8s,已经是卡着过了 for(m=1;m<min(a1,n)+min(a2,n);m<<=1); Polynomial_Multiplication(B,C,B,m); for(m=1;m<min(a1,n)+min(a2,n)+min(a3,n);m<<=1); Polynomial_Multiplication(B,D,B,m); for(m=1;m<min(a1,n)+min(a2,n)+min(a3,n)+min(a4,n);m<<=1); Polynomial_Multiplication(B,E,B,m); */  for(m=1;m<min(a0,n)+min(a1,n)+min(a2,n)+min(a3,n)+min(a4,n);m<<=1);
        if(!a0)
        {
            Polynomial_Multiplication(A,B,F,m);
            printf("Case #%d: %I64d\n",t,F[n]*fac[n]%MOD);
        }
        else
        {
            Polynomial_Multiplication(A,B,F,m);
            ll ans=F[n]*fac[n]%MOD;
            A[min(a0,n)]=0;
            Polynomial_Multiplication(A,B,F,m);
            printf("Case #%d: %I64d\n",t,((ans-F[n-1]*fac[n-1]%MOD)%MOD+MOD)%MOD);
        }
    }
}

你可能感兴趣的:(fft,ntt)