「学习笔记」FFT 之优化——NTT

文章目录

  • 「学习笔记」FFT 之优化——NTT
    • 前言
    • 引入
    • 快速数论变换——NTT
    • 一些引申问题及解决方法
      • 三模数 NTT
      • 拆系数 FFT (MTT)

「学习笔记」FFT 之优化——NTT

前言

N T T NTT NTT 在某种意义上说,应该属于 F F T FFT FFT 的一种优化。

——因而必备知识肯定要有 F F T FFT FFT 啦…

如果不知道 F F T FFT FFT 的大佬可以走这里

引入

F F T FFT FFT 中,为了能计算单位原根 ω \omega ω ,我们使用了 C++ \text{C++} C++math 库中的 c o s 、 s i n cos、sin cossin 函数,所以我们无法避免地使用了 double 以及其运算。

但是,众所周知的, double 的运算很慢,并且,我们的虚数乘法是类似于下面这种打法:

cplx operator * (const cplx a)const{return cplx(vr*a.vr-vi*a.vi,vr*a.vi+a.vr*vi);}

显然,一次虚数乘法涉及四次 double 的乘法。

并且在运算过程中,会有大量的精度丢失,这都是我们不可接受的。

然而问题来了:我们多项式乘法都是整数在那里搞来搞去,为什么一定要扯到浮点数。是否存在一个在模意义下的,只使用整数的方法?——Tiw_Air_OAO

快速数论变换——NTT

想一想我们使用了单位复根的哪些特性:

  1. w n i ∗ w n j = w n i + j w_{n}^{i}*w_{n}^{j}=w_{n}^{i+j} wniwnj=wni+j
  2. w d n d k = w n k w_{dn}^{dk}=w_n^k wdndk=wnk
  3. w 2 n k = − w 2 n k + n w_{2n}^k=-w_{2n}^{k+n} w2nk=w2nk+n
  4. n n n 个单位根互不相同,且 w n 0 = 1 w_n^0=1 wn0=1

那么我们能否在 模意义 下找到一个性质相同的数?

这里有一个同样也是 某某根 的东西,叫做 原根

对于素数 p p p p p p 的原根 G G G 定义为使得 G 0 , G 1 , . . . , G p − 2 ( m o d   p ) G^0,G^1,...,G^{p−2}(mod\space p) G0,G1,...,Gp2(mod p) 互不相同的数。

仔细思考一下,发现 原根单位复根 很像。

同理,我们再定义 g n k = ( G p − 1 n ) k g_n^k = (G^{\frac{p-1}{n}})^k gnk=(Gnp1)k ,这样 g n k g_n^k gnk 就与 ω n k \omega_n^k ωnk 长得更像了…

但是,必须在 g n k g_n^k gnk 满足与 ω n k \omega_n^k ωnk 同样的性质时,我们才能等价替换。

现在,我们检验原根在模意义下是否满足与单位复根同样的性质:

  1. 由幂的运算立即可得
  2. 由幂的运算立即可得
  3. g 2 n k + n = ( G p − 1 2 n ) k + n = ( G p − 1 2 n ) k ∗ ( G p − 1 2 n ) n = G p − 1 2 ∗ g 2 n k = − g 2 n k ( m o d   p ) g_{2n}^{k+n}=(G^{\frac{p-1}{2n}})^{k+n}=(G^{\frac{p-1}{2n}})^k*(G^{\frac{p-1}{2n}})^n=G^{\frac{p-1}{2}}*g_{2n}^k=-g_{2n}^k(mod\space p) g2nk+n=(G2np1)k+n=(G2np1)k(G2np1)n=G2p1g2nk=g2nk(mod p) ,因为 ( G p − 1 = 1 ( m o d   p ) (G^{p-1}=1(mod\space p) (Gp1=1(mod p) 且由原根定义 G p − 1 2 ≠ G p − 1 ( m o d   p ) G^{\frac{p-1}{2}}\not=G^{p-1}(mod\space p) G2p1=Gp1(mod p) ,故 G p − 1 2 = − 1 ( m o d   p ) G^{\frac{p-1}{2}}=-1(mod\space p) G2p1=1(mod p)
  4. 由原根的定义立即可得

发现原根可以在模意义下 完全替换 单位复根。

这就是 N T T NTT NTT 了。

但是,这样的方法对模数会有一定的限制

m = 2 p ∗ k + 1 m = 2^p*k+1 m=2pk+1 k k k 为奇数,则多项式长度必须 n ≤ 2 p n \le 2^p n2p

至于模数以及其原根,没有必要来死记,为什么?

我们程序员就应该干我们经常干的事情——打表可得…

以下是参考代码:

#include
#include
using namespace std;

#define rep(i,__l,__r) for(register int i=__l,i##_end_=__r;i<=i##_end_;++i)
#define fep(i,__l,__r) for(register int i=__l,i##_end_=__r;i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define pii pair
#define Endl putchar('\n')
// #define FILEOI
// #define int long long

#ifdef FILEOI
    #define MAXBUFFERSIZE 500000
    inline char fgetc(){
        static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
    }
    #undef MAXBUFFERSIZE
    #define cg (c=fgetc())
#else
    #define cg (c=getchar())
#endif
template<class T>inline void qread(T& x){
    char c;bool f=0;
    while(cg<'0'||'9'<c)f|=(c=='-');
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    if(f)x=-x;
}
inline int qread(){
    int x=0;char c;bool f=0;
    while(cg<'0'||'9'<c)f|=(c=='-');
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    return f?-x:x;
}
template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
template<class T>inline T Max(const T x,const T y){return x>y?x:y;}
template<class T>inline T Min(const T x,const T y){return x<y?x:y;}
template<class T>inline T fab(const T x){return x>0?x:-x;}
inline int gcd(const int a,const int b){return b?gcd(b,a%b):a;}
inline void getInv(int inv[],const int lim,const int MOD){
    inv[0]=inv[1]=1;for(int i=2;i<=lim;++i)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
}
template<class T>void fwrit(const T x){
    if(x<0)return (void)(putchar('-'),fwrit(-x));
    if(x>9)fwrit(x/10);putchar(x%10^48);
}
inline LL mulMod(const LL a,const LL b,const LL mod){//long long multiplie_mod
    return ((a*b-(LL)((long double)a/mod*b+1e-8)*mod)%mod+mod)%mod;
}

const int MAXN=3e6;
const int MOD=998244353,g=3,gi=332748118;
int n,m;
int a[MAXN+5],b[MAXN+5],revi[MAXN+5];
inline int qkpow(int a,int x){
    int ret=1;
    for(;x>0;x>>=1){
         if(x&1)ret=1ll*ret*a%MOD;
        a=1ll*a*a%MOD;
    }
    return ret;
}
inline void ntt(int* f,const short opt=1){
    for(int i=0;i<n;++i)if(i<revi[i])swap(f[i],f[revi[i]]);
    for(int p=2,len,gn,Pow,tmp;p<=n;p<<=1){
        len=p>>1,gn=qkpow(opt==1?g:gi,(MOD-1)/p);
        for(int k=0;k<n;k+=p){Pow=1;
            for(int l=k;l<k+len;++l,Pow=1ll*Pow*gn%MOD){
                tmp=1ll*Pow*f[len+l]%MOD;
                if(f[l]-tmp<0)f[len+l]=f[l]-tmp+MOD;
                else f[len+l]=f[l]-tmp;
                if(f[l]-MOD+tmp>0)f[l]=f[l]-MOD+tmp;
                else f[l]+=tmp;
            }
        }
    }
    if(opt==-1){
        int inv=qkpow(n,MOD-2);
        for(int i=0;i<n;++i)f[i]=1ll*f[i]*inv%MOD;
    }
}
inline void launch(){
    qread(n,m);
    rep(i,0,n)qread(a[i]);
    rep(i,0,m)qread(b[i]);
    for(m+=n,n=1;n<=m;n<<=1);
    for(int i=0;i<n;++i)revi[i]=(revi[i>>1]>>1)|((i&1)?n>>1:0);
    ntt(a),ntt(b);
    for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%MOD;
    ntt(a,-1);
    rep(i,0,m)writc(a[i],' ');
    Endl;
}

signed main(){
#ifdef FILEOI
    freopen("file.in","r",stdin);
    freopen("file.out","w",stdout);
#endif
    launch();
    return 0;
}

一些引申问题及解决方法

假如题目中规定了模数怎么办?还卡 FFT 的精度怎么办?

有两种方法:三模数 NTT 以及 拆系数 FFT (MTT)

三模数 NTT

我们可以选取三个适用于 N T T NTT NTT 的模数 M 1 , M 2 , M 3 M1,M2,M3 M1M2M3 进行 N T T NTT NTT ,用中国剩余定理合并得到 x   m o d   ( M 1 ∗ M 2 ∗ M 3 ) x\space mod\space (M1*M2*M3) x mod (M1M2M3) 的值。只要保证 x < M 1 ∗ M 2 ∗ M 3 x < M1*M2*M3 x<M1M2M3 就可以直接输出这个值。

之所以是三模数,因为用三个大小在 1 0 9 10^9 109 左右模数对于大部分题目来说就足够了。

但是 M 1 ∗ M 2 ∗ M 3 M1*M2*M3 M1M2M3 可能非常大怎么办呢?难不成我还要写高精度?其实也可以。

我们列出同余方程组:

{ x ≡ a 1 m o d    m 1 x ≡ a 2 m o d    m 2 x ≡ a 3 m o d    m 3 \begin{cases} x \equiv a_1&\mod m_1\\ x \equiv a_2&\mod m_2\\ x \equiv a_3&\mod m_3\\ \end{cases} xa1xa2xa3modm1modm2modm3

用中国剩余定理合并前两个方程组,得到:
{ x ≡ A m o d    M x ≡ a 3 m o d    m 3 \begin{cases} x \equiv A&\mod M\\ x \equiv a_3&\mod m_3\\ \end{cases} {xAxa3modMmodm3
其中的 M M M 满足 M = m 1 ∗ m 2 < 1 0 18 M = m1*m2 < 10^{18} M=m1m2<1018

然后将第一个方程变形得到 x = k M + A x = kM + A x=kM+A ,代入第二个方程,得到:
k M + A ≡ a 3 m o d    m 3 k ≡ ( a 3 − A ) ∗ M − 1 m o d    m 3 kM+A \equiv a_3\mod m_3\\ k \equiv (a_3-A)*M^{-1} \mod m_3\\ kM+Aa3modm3k(a3A)M1modm3
Q = ( a 3 − A ) ∗ M − 1 Q = (a_3-A)*M^{-1} Q=(a3A)M1 ,则 k = P m 3 + Q k = Pm_3 + Q k=Pm3+Q

再将上式代入回 x = k M + A x = kM + A x=kM+A ,得 x = ( P m 3 + Q ) M + A = P m 3 M + Q M + A x = (Pm_3 + Q)M+ A = Pm_3M+QM+A x=(Pm3+Q)M+A=Pm3M+QM+A

又因为 M = m 1 m 2 M = m_1m_2 M=m1m2 ,所以 x = P m 1 m 2 m 3 + Q M + A x = Pm_1m_2m_3 + QM + A x=Pm1m2m3+QM+A

也就是说 x ≡ Q M + A m o d    m 1 m 2 m 3 x \equiv QM + A \mod m_1m_2m_3 xQM+Amodm1m2m3

然后,我们完美地解决了这个东西。

接下来是代码:

#include
#include
using namespace std;

#define rep(i,__l,__r) for(register int i=__l,i##_end_=__r;i<=i##_end_;++i)
#define fep(i,__l,__r) for(register int i=__l,i##_end_=__r;i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define pii pair
#define Endl putchar('\n')
// #define FILEOI
#define int long long

#ifdef FILEOI
    #define MAXBUFFERSIZE 500000
    inline char fgetc(){
        static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
    }
    #undef MAXBUFFERSIZE
    #define cg (c=fgetc())
#else
    #define cg (c=getchar())
#endif
template<class T>inline void qread(T& x){
    char c;bool f=0;
    while(cg<'0'||'9'<c)f|=(c=='-');
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    if(f)x=-x;
}
inline int qread(){
    int x=0;char c;bool f=0;
    while(cg<'0'||'9'<c)f|=(c=='-');
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    return f?-x:x;
}
template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
template<class T>inline T Max(const T x,const T y){return x>y?x:y;}
template<class T>inline T Min(const T x,const T y){return x<y?x:y;}
template<class T>inline T fab(const T x){return x>0?x:-x;}
inline int gcd(const int a,const int b){return b?gcd(b,a%b):a;}
inline void getInv(int inv[],const int lim,const int MOD){
    inv[0]=inv[1]=1;for(int i=2;i<=lim;++i)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
}
template<class T>void fwrit(const T x){
    if(x<0)return (void)(putchar('-'),fwrit(-x));
    if(x>9)fwrit(x/10);putchar(x%10^48);
}
inline LL mulMod(const LL a,const LL b,const LL mod){//long long multiplie_mod
    return ((a*b-(LL)((long double)a/mod*b+1e-8)*mod)%mod+mod)%mod;
}

inline int qkpow(int a,int n,const int mod){
    int ret=1;
    for(;n>0;n>>=1){
        if(n&1)ret=ret*a%mod;
        a=a*a%mod;
    }
    return ret;
}

const int MAXN=3e5;
const int MOD[3]={469762049ll,998244353ll,1004535809ll};//三模数
const int G=3;//共用的原根
int inv[3][3],k1,k2,Inv,M;
//inv[i][j] : MOD[i] 在 (mod MOD[j]) 下的逆元

int h[3][MAXN+5],g[3][MAXN+5];
//h/g[i][j] : 原函数的第 j 位在 (mod MOD[i]) 的情况下的值

int revi[MAXN+5];//反转数组

inline void init(){
    rep(i,0,2)rep(j,0,2)if(i!=j)//处理 inv 数组, 主要用到费马小定理
        inv[i][j]=qkpow(MOD[i],MOD[j]-2,MOD[j]);
    M=MOD[0]*MOD[1];
    k1=mulMod(MOD[1],inv[1][0],M);
    k2=mulMod(MOD[0],inv[0][1],M);
    Inv=inv[0][2]*inv[1][2]%MOD[2];
}

inline int crt(const int a1,const int a2,const int a3,const int mod){
    int A=(mulMod(a1,k1,M)+mulMod(a2,k2,M))%M;
    int K=(a3+MOD[2]-A%MOD[2])%MOD[2]*Inv%MOD[2];
    return ((M%mod)*K%mod+A)%mod;
}

inline void ntt(int* f,const int n,const int m,const short opt=1){
    /*
        和普通的 ntt 没啥区别, 如果有什么问题, 可以去查查 fft 的资料
        唯一有区别的地方在于取模的时候, 要根据我们目前计算的模数下的运算来取模
    */
    for(int i=0;i<n;++i)if(i<revi[i])swap(f[i],f[revi[i]]);
    for(int s=2;s<=n;s<<=1){
        int t=s>>1,u=(opt==-1)?qkpow(G,(MOD[m]-1)/s,MOD[m]):qkpow(G,MOD[m]-1-(MOD[m]-1)/s,MOD[m]);
        for(int i=0;i<n;i+=s){int w=1;
            for(int j=i;j<i+t;++j,w=w*u%MOD[m]){
                int x=f[j],y=w*f[j+t]%MOD[m];
                f[j]=(x+y)%MOD[m];
                f[j+t]=(x-y+MOD[m])%MOD[m];
            }
        }
    }
    if(opt==-1){
        int inv=qkpow(n,MOD[m]-2,MOD[m]);
        rep(i,0,n-1)f[i]=f[i]*inv%MOD[m];
    }
}

int n,m,p;

inline void launch(){
    init();
    qread(n,m,p);
    rep(i,0,n){//这里我输入的最大的一个模数, 因为其已经超过 1e9 的范围, 刚刚输入时不用取模
        qread(h[2][i]);
        h[1][i]=h[2][i]%MOD[1];
        h[0][i]=h[2][i]%MOD[0];
    }
    rep(i,0,m){
        qread(g[2][i]);
        g[1][i]=g[2][i]%MOD[1];
        g[0][i]=g[2][i]%MOD[0];
    }
    for(m+=n,n=1;n<=m;n<<=1);
    for(int i=0;i<n;++i)revi[i]=(revi[i>>1]>>1)|((i&1)?n>>1:0);
    rep(i,0,2){
        ntt(h[i],n,i),ntt(g[i],n,i);
        rep(j,0,n-1)h[i][j]=h[i][j]*g[i][j]%MOD[i];
        ntt(h[i],n,i,-1);
    }
    for(int i=0;i<=m;++i)
        writc(crt(h[0][i],h[1][i],h[2][i],p),' ');
    //使用 crt(我国剩余定理) 来还原答案
    // rep(i,0,m)printf("%lld %lld %lld\n",h[0][i],h[1][i],h[2][i]);
}

signed main(){
#ifdef FILEOI
    freopen("file.in","r",stdin);
    freopen("file.out","w",stdout);
#endif
    launch();
    return 0;
}

拆系数 FFT (MTT)

我太菜了,还不会…等我更新吧…

你可能感兴趣的:(fft)