任意模数FFT 板子

这个任意模数的东西是myy论文里写的(orz myy),论文里myy提到了很多优化常数的方法,能将其DFT次数优化到4次甚至3.5次,然而我看不懂
于是我打了7次DFT的版本,但是实测好像不慢不知道是不是数据的问题
大致思想是拆因数,将一个数拆成 aM+b ,然后对两个a和两个b做4次DFT,乘到3个c数组里,再对3个c做3次IDFT,最后把答案转回来
因为把因数拆了,所以卷积结果每个数的大小都不算太大,就不会被卡精度,所以DFT部分直接用FFT

测试题:BZOJ3992(NTT和这个FFT都能A,但直接上裸的FFT精度就炸了过不了)
核心部分:

mtt operator * (mtt x,mtt y)
{
    for(int i=0;ix=x.s[i]/qmod; a1[i].y=0.0;
        a2[i].x=x.s[i]%qmod; a2[i].y=0.0;
        b1[i].x=y.s[i]/qmod; b1[i].y=0.0;
        b2[i].x=y.s[i]%qmod; b2[i].y=0.0;
    }
    FFT(a1,1); FFT(a2,1);
    FFT(b1,1); FFT(b2,1);
    for(int i=0;i*b1[i];
        c2[i]=a2[i]*b1[i]+a1[i]*b2[i];
        c3[i]=a2[i]*b2[i];
    }
    FFT(c1,-1); FFT(c2,-1); FFT(c3,-1);
    mtt ret=zero;
    for(int i=0;ix+0.5)%Mod*qmod%Mod*qmod%Mod + 
                    (ll)(c2[i].x+0.5)%Mod*qmod%Mod 
                    + (ll)(c3[i].x+0.5)%Mod)%Mod;

        if(i<m) ret.s[i]+=temp;
        else ret.s[_to[i]]+=temp;
    }
    for(int i=0;is[i]%=Mod;
    return ret;
}


全部代码(BZOJ3992):

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#define ll long long
using namespace std;

const int maxn = 20000;
const ll Mod = 1004535809;
const double pi=acos(-1);

int xn,ansn,n,N,ln,id[maxn];
int a[maxn],m,len;
ll qmod;

ll pw(ll x,int k,ll mod)
{
    x%=mod;
    ll ret=1,tmp=x;
    int nowk=0,tmpk=1;
    while(nowk!=k)
    {
        if(tmpk&k)
        {
            ret=ret*tmp%mod;
            nowk|=tmpk;
        }
        tmpk<<=1;tmp=tmp*tmp%mod;
    }
    return ret;
}

struct E
{
    double x,y;
    E(){x=y=0.0;}
    E(double _x,double _y){x=_x;y=_y;}
}a1[maxn],a2[maxn],b1[maxn],b2[maxn],c1[maxn],c2[maxn],c3[maxn],w[maxn];
E operator +(E x,E y){return E(x.x+y.x,x.y+y.y);}
E operator -(E x,E y){return E(x.x-y.x,x.y-y.y);}
E operator *(E x,E y){return E(x.x*y.x-x.y*y.y,x.y*y.x+x.x*y.y);}

struct mtt
{
    ll s[maxn];
}zero; int _to[maxn];
void FFT(E *s,int sig)
{
    for(int i=0;iif(is[i],s[id[i]]);
    for(int mm=2;mm<=n;mm<<=1)
    {
        int t=mm>>1,tt=n/mm;
        for(int i=0;i1?w[i*tt]:w[n-i*tt];
            for(int j=i;js[j],ty=s[j+t]*wn;
                s[j]=tx+ty;
                s[j+t]=tx-ty;
            }
        }
    }
    if(sig==-1) for(int i=0;is[i].x/=(double)n;
}
mtt operator * (mtt x,mtt y)
{
    for(int i=0;ix=x.s[i]/qmod; a1[i].y=0.0;
        a2[i].x=x.s[i]%qmod; a2[i].y=0.0;
        b1[i].x=y.s[i]/qmod; b1[i].y=0.0;
        b2[i].x=y.s[i]%qmod; b2[i].y=0.0;
    }
    FFT(a1,1); FFT(a2,1);
    FFT(b1,1); FFT(b2,1);
    for(int i=0;i*b1[i];
        c2[i]=a2[i]*b1[i]+a1[i]*b2[i];
        c3[i]=a2[i]*b2[i];
    }
    FFT(c1,-1); FFT(c2,-1); FFT(c3,-1);
    mtt ret=zero;
    for(int i=0;ix+0.5)%Mod*qmod%Mod*qmod%Mod + 
                    (ll)(c2[i].x+0.5)%Mod*qmod%Mod 
                    + (ll)(c3[i].x+0.5)%Mod)%Mod;

        if(i<m) ret.s[i]+=temp;
        else ret.s[_to[i]]+=temp;
    }
    for(int i=0;is[i]%=Mod;
    return ret;
}

int p[maxn],pn,ind[maxn];
void divide(int x)
{
    int t=sqrt(x*1.0);pn=0;
    for(int i=2;i<=t;i++)
    {
        if(x%i==0)
        {
            p[++pn]=i;
            while(x%i==0)x/=i;
        }
    }
    if(x>1)p[++pn]=x;
}
int get_g(int n)
{
    divide(n-1);
    for(int i=2;ifor(int j=1;j<=pn;j++)
        {
            if(pw(i,(n-1)/p[j],n)==1){flag=false; break;}
        }
        if(flag) return i;
    }
}

mtt get_ans(mtt x,int k)
{
    mtt ret,tmp=x;
    int nowk=0,tmpk=1;
    while(nowk!=k)
    {
        if(tmpk&k)
        {
            if(nowk==0)ret=tmp;
            else ret=ret*tmp;
            nowk|=tmpk;
        }
        tmpk<<=1; tmp=tmp*tmp;
    }
    return ret;
} 

int main()
{
    scanf("%d%d%d%d",&ansn,&m,&xn,&len);
    for(int i=1;i<=len;i++) scanf("%d",&a[i]);

    for(n=1,ln=0;n<(m+m);n<<=1,ln++);
    for(int i=0;i>1]>>1)|((i&1)<<(ln-1));
    for(int mm=2;mm<=n;mm<<=1)
    {
        int t=mm>>1,tt=n/mm;
        for(int i=0;i*tt]=E(cos(i*2*pi/mm),sin(i*2*pi/mm));
            w[n-i*tt]=E(cos(i*2*pi/mm),sin(-i*2*pi/mm));
        }
    }

    ll mg=get_g(m);
    for(ll tmp=mg,i=1;iif(i<m) ind[tmp]=i;
        else _to[i]=ind[tmp];
        tmp=tmp*mg%m;
    }

    mtt st,zero; for(int i=0;is[i]=0;
    st=zero;
    for(int i=1;i<=len;i++) if(a[i]!=0) st.s[ind[a[i]]]=1;
    qmod=sqrt(Mod*1.0);
    mtt ret=get_ans(st,ansn);
    printf("%lld\n",(ret.s[ind[xn]]%Mod+Mod)%Mod);

    return 0;
}

把单位复数根预处理后真的挺快的,不知道为什么比我的NTT还快明明对拍的话是NTT快的

参考资料: 《再探快速傅里叶变换》 by 雅礼中学 毛啸

你可能感兴趣的:(板子,快速傅里叶变换(FFT))