任意模数快速傅立叶变换的两种方法

1.三模数NTT

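对于系数都在 $[0, P)$ 范围内的多项式 $A(x)$ 和 $B(x)$，做一次卷积后每一项系数不超过 $nP^2$（例如 $n=10^5$、$P=10^9$ 时约为 $10^{23}$）。取三个 NTT 模数分别做 NTT，再用中国剩余定理合并：三个模数之积约为 $4.7\times10^{26} > nP^2$，所以合并出的就是真实系数。如果不想用高精度或者 `__int128`，可以参考下面链接中的做法：先合并前两个模数得到 $x_{12}$，再把答案写成 $x = x_{12} + k\,m_1 m_2$，直接在模 $P$ 意义下计算。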
https://blog.csdn.net/u014609452/article/details/68058602

板子题:P4245 【模板】任意模数NTT

code:

// luogu-judger-enable-o2
#include <algorithm>
#include <cstdio>
#include <cstring>

using namespace std;
typedef long long ll;

const int maxn = 2e6 + 10;   // capacity of every global array; must be >= the transform length N
const int g = 3;             // primitive root shared by the three NTT primes below
const double eps = 1e-3;     // small floating-point tolerance
int mod;                     // output modulus, read in main
int rev[maxn];               // bit-reversal permutation for the current transform length

// Computes (a * b) mod c without __int128. Requires 0 < c < 2^62.
// The quotient floor(a*b/c) is estimated in long double. Assuming long double
// has at least a 64-bit mantissa (x87 extended precision, as with GCC on x86),
// the estimate is within +-1 of the true quotient. So the remainder computed with
// wrap-around unsigned arithmetic lies in [-c, 2c), and one correction step in
// either direction brings it into [0, c).
// Negative a or b are first normalised into [0, c).
long long qmul(long long a, long long b, long long c){
    a %= c; if(a < 0) a += c;
    b %= c; if(b < 0) b += c;
    const unsigned long long q = (unsigned long long)((long double)a * b / c);
    // Unsigned multiplication wraps modulo 2^64 (well defined), unlike signed overflow.
    long long ret = (long long)((unsigned long long)a * (unsigned long long)b
                                - q * (unsigned long long)c);
    if(ret < 0) ret += c;        // quotient estimate was one too large
    else if(ret >= c) ret -= c;  // quotient estimate was one too small
    return ret;
}

// Binary exponentiation: returns a^b mod P (returns 1 when b == 0).
// Intermediate products are base*base, so P must be small enough that
// (P-1)^2 fits in a long long.
inline ll qpow(ll a,ll b,ll P){
    ll base = a % P;
    ll result = 1;
    while(b){
        if(b & 1) result = result * base % P;  // current bit set: fold this power in
        base = base * base % P;                // advance to the next power of two
        b >>= 1;
    }
    return result;
}

// Three NTT-friendly primes (119*2^23+1, 479*2^21+1, 7*2^26+1).
const int m1 = 998244353;
const int m2 = 1004535809;
const int m3 = 469762049;
const ll _M = (ll)m1 * m2;                      // m1*m2, still fits in 64 bits
// Modular inverses via Fermat's little theorem (all moduli are prime).
const int inv1 = qpow(m1 % m2, m2 - 2, m2);     // m1^{-1} mod m2
const int inv2 = qpow(m2 % m1, m1 - 2, m1);     // m2^{-1} mod m1
const int inv12 = qpow(_M % m3, m3 - 2, m3);    // (m1*m2)^{-1} mod m3


// Reconstructs x from x mod m1 = a1, x mod m2 = a2, x mod m3 = a3 and returns x mod `mod`,
// never forming x itself (x may exceed 64 bits).
ll CRT(ll a1, ll a2, ll a3){
    // Step 1: merge the first two residues into x12 in [0, m1*m2).
    const ll part1 = qmul(a1 * m2 % _M, inv2, _M);
    const ll part2 = qmul(a2 * m1 % _M, inv1, _M);
    const ll x12 = (part1 + part2) % _M;
    // Step 2: x = x12 + k*m1*m2 with k = (a3 - x12) * (m1*m2)^{-1} mod m3.
    ll k = (a3 - x12) % m3;
    if(k < 0) k += m3;
    k = k * inv12 % m3;
    // Step 3: evaluate x12 + k*_M directly modulo the output modulus.
    return (k % mod * (_M % mod) % mod + x12 % mod) % mod;
}

// Number-theoretic transform modulo a single prime P.
// Pre() must be called before DFT(); DFT() reads the global bit-reversal table rev[].
struct NTT{
    int P;                 // prime modulus of this transform
    int num,w[2][maxn];    // num: root-table length; w[1][i] = wn^i, w[0][i] = wn^{-i}

    // Precomputes powers of wn = g^((_P-1)/m) mod _P and of its inverse.
    // wn is a primitive m-th root of unity only if g is a primitive root of _P
    // and m divides _P - 1.
    void Pre(int _P,int m){
        num = m; P = _P;
        const int wn = qpow(g,(P-1)/num,P);
        const int _wn = qpow(wn,P-2,P);      // inverse root (Fermat, P prime)
        w[1][0] = w[0][0] = 1;
        for(int i = 1;i<num;i++){
            w[1][i] = (ll)w[1][i-1] * wn % P;
            w[0][i] = (ll)w[0][i-1] * _wn % P;
        }
    }

    // In-place iterative transform of a[0..N). r = 1: forward; r = 0: inverse,
    // including the 1/N scaling. N must be a power of two dividing num, and
    // rev[] must hold the bit-reversal permutation for N.
    void DFT(int* a,int N,int r){
        // Reduce the input into [0, P) first. The butterfly evaluates x + P - y
        // in int, which only stays in range when x < P (and 2P fits in int).
        for(int i = 0;i<N;i++){
            a[i] %= P;
            if(a[i] < 0) a[i] += P;
        }
        for(int i = 1;i<N;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
        for(int i = 1;i<N;i<<=1){
            const int step = num/(i<<1);   // stride into the root table at this level
            for(int j = 0;j<N;j+=(i<<1))
                for(int k = 0;k<i;k++){
                    const int x = a[j+k];
                    const int y = 1LL * a[i+j+k] * w[r][step*k] % P;
                    a[j+k] = (x+y) % P;
                    a[i+j+k] = (x + P - y) % P;
                }
        }
        if(!r){
            const int Inv = qpow(N,P-2,P);   // N^{-1} mod P
            for(int i = 0;i<N;i++) a[i] = 1LL * a[i] * Inv % P;
        }
    }
}ntt[3];

// A, B: input coefficients; C, D: per-prime working copies; tmp[i]: product modulo prime i.
int A[maxn],B[maxn],C[maxn],D[maxn],tmp[3][maxn];

// Reads n, m, the output modulus and the coefficients of two polynomials, multiplies
// them with three independent NTTs and prints each product coefficient merged by CRT.
int main(){
    int n,m;
    scanf("%d%d%d",&n,&m,&mod);
    for(int i = 0;i<=n;i++) scanf("%d",&A[i]);
    for(int i = 0;i<=m;i++) scanf("%d",&B[i]);
    // The product has n + m + 1 coefficients. Any power of two N >= that length makes
    // the cyclic convolution equal to the linear one, so take the smallest one.
    const int len = n + m + 1;
    int N = 1, L = 0;                 // invariant: N == 2^L
    while(N < len){ N <<= 1; L++; }
    ntt[0].Pre(m1,N);
    ntt[1].Pre(m2,N);
    ntt[2].Pre(m3,N);
    for(int i = 1;i<N;i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(L-1));
    for(int i = 0;i<3;i++){
        memcpy(C,A,sizeof(int) * N);
        memcpy(D,B,sizeof(int) * N);
        ntt[i].DFT(C,N,1);
        ntt[i].DFT(D,N,1);
        for(int j = 0;j<N;j++) tmp[i][j] = (ll)C[j] * D[j] % ntt[i].P;
        ntt[i].DFT(tmp[i],N,0);
    }
    for(int i = 0;i<len;i++) printf("%lld ",CRT(tmp[0][i],tmp[1][i],tmp[2][i]));
    return 0;
}

2. 拆系数FFT

又称 MTT。
大致做法是取 $M=\lfloor\sqrt{P}\rfloor$，把系数拆成 $f(x) = M\cdot k(x) + r(x)$ 的形式（$k$ 为商、$r$ 为余数，二者都只有 $\sqrt{P}$ 量级），分别卷积后再还原回去。朴素的版本一共需要 7 次 DFT（4 次正变换、3 次逆变换）。拆分后每次卷积的结果在 $nP$ 级别，为避免精度误差，需要用到 `long double`，并预处理单位根。

板子题:P4245 【模板】任意模数NTT

code:不知道为什么跑得比三模数NTT慢啊（可能的原因：这份实现要用 `long double` 做 7 次复数变换，而 `long double` 运算通常比 32/64 位整数运算慢得多）

// luogu-judger-enable-o2
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>

using namespace std;

const int maxn = 6e5 + 10;                        // capacity of every global array (>= transform length)
const long double pi = acos((long double)-1.0);   // pi in long double precision
const double eps = 1e-3;                          // keeps sqrt(mod) in main from rounding down
typedef long long ll;
int mod,M;                                        // mod: output modulus; M: split base set in main

// Minimal complex number in long double precision.
// Operators are const-qualified and take their argument by const reference,
// so they also work on const values and temporaries.
struct cp{
    long double r,i;   // real and imaginary parts
    cp(long double _r=0,long double _i=0):r(_r),i(_i){}
    cp operator + (const cp& x) const { return cp(r+x.r,i+x.i); }
    cp operator - (const cp& x) const { return cp(r-x.r,i-x.i); }
    cp operator * (const cp& x) const { return cp(r*x.r-i*x.i,r*x.i+i*x.r); }
    cp conj() const { return cp(r,-i); }
};

int A[maxn], B[maxn];            // input coefficients, read in main
int rev[maxn];                   // bit-reversal permutation for the current length
cp a[maxn], b[maxn];             // passed to MTT
cp k1[maxn], k2[maxn];           // quotient parts A/M and B/M
cp r1[maxn], r2[maxn];           // remainder parts A%M and B%M
cp s1[maxn], s2[maxn];           // k1*k2 and the cross term k1*r2 + k2*r1
cp s3[maxn];                     // r1*r2
cp w[2][maxn];                   // root table: w[1] forward, w[0] its conjugate
int ans[maxn];                   // product coefficients reduced mod `mod`

void get_wn(int N){
    for(int i = 0;i<N;i++){
        w[1][i] = cp(cos(2*pi/N * i),sin(2*pi/N*i));
        w[0][i] = w[1][i].conj();
    }
}

// In-place iterative radix-2 FFT of a[0..N), using the global root table w and the
// global bit-reversal permutation rev[]. r = 1: forward transform with w[1];
// r = 0: inverse transform with w[0], including the 1/N scaling.
void DFT(cp* a,int N,int r){
    for(int i = 0;i<N;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int l = 2; l<=N;l<<=1){
        const int m = l >> 1;
        const int stride = N / l;   // index step into w for this level
        // Blocks start at 0, l, 2l, ... and must stay inside [0, N): the bound is j < N.
        for(int j = 0; j < N; j += l)
            for(int k = 0;k<m;k++){
                const cp u = w[r][stride*k] * a[j+k+m];
                a[j+k+m] = a[j+k] - u;
                a[j+k] = a[j+k] + u;
            }
    }
    // Scale both components so the inverse transform is exact, not only its real part.
    if(r == 0) for(int i = 0;i<N;i++){ a[i].r /= N; a[i].i /= N; }
}

// Multiplies the global polynomials A and B modulo `mod` by splitting every
// coefficient as x = M*(x/M) + x%M. It runs three real-valued convolutions
// (7 FFTs in total) and writes the first N product coefficients into ans[].
// The parameters a and b are not used; the inputs come from the globals A and B.
void MTT(cp* a,cp* b,int N){
    for(int i = 0;i<N;i++){
        k1[i] = cp(A[i]/M,0); r1[i] = cp(A[i]%M,0);
        k2[i] = cp(B[i]/M,0); r2[i] = cp(B[i]%M,0);
    }
    DFT(k1,N,1); DFT(r1,N,1);
    DFT(k2,N,1); DFT(r2,N,1);
    for(int i = 0;i<N;i++) {
        s1[i] = k1[i] * k2[i];                    // high * high
        s2[i] = k1[i] * r2[i] + k2[i] * r1[i];    // cross terms
        s3[i] = r1[i] * r2[i];                    // low * low
    }
    DFT(s1,N,0); DFT(s2,N,0); DFT(s3,N,0);
    const ll MM = (ll)M * M % mod;                // M^2 mod `mod`, hoisted out of the loop
    for(int i = 0;i<N;i++){
        // Round to the nearest integer; the exact values are non-negative.
        const ll x1 = (ll)(s1[i].r + 0.5) % mod * MM % mod;
        const ll x2 = (ll)(s2[i].r + 0.5) % mod * M % mod;
        const ll x3 = (ll)(s3[i].r + 0.5) % mod;
        // Sum in 64 bits: x1 + x2 + x3 can exceed INT_MAX when mod is close to 2^31.
        ans[i] = (int)((x1 + x2 + x3) % mod);
    }
}

// Reads n, m, the modulus and the coefficients of two polynomials, multiplies
// them with the split-coefficient FFT and prints the n + m + 1 product coefficients.
int main(){
    int n,m;
    scanf("%d%d%d",&n,&m,&mod);
    M = sqrt(mod + eps);   // split base, about sqrt(mod); eps guards against rounding down
    for(int i = 0;i<= n;i++) scanf("%d",&A[i]);
    for(int i = 0;i<= m;i++) scanf("%d",&B[i]);
    // The product has n + m + 1 coefficients; the smallest power of two N >= that
    // suffices for the cyclic convolution to equal the linear one.
    const int len = n + m + 1;
    int L,N;                            // invariant: N == 2^L
    for(L=0,N=1;N<len;L++,N<<=1);
    for(int i = 1;i<N;i++) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
    get_wn(N);
    MTT(a,b,N);
    for(int i = 0;i<len;i++) printf("%d ",ans[i]);
    return 0;
}

你可能感兴趣的:(fft&&ntt)