NTT学习笔记

NTT学习笔记

前言

FFT

  • 我们知道\(FFT\)可以快速的完成两个多项式的乘法,利用了单位复根的特殊性质。
  • 由于复数的实部与虚部是正余弦函数,需要做浮点数运算,以及产生误差。
  • 这样计算量比较大,而且复数不可以取模。

NTT

  • 中文名:快速数论变换。

  • 多项式乘法有时候会建立在模域,对一些特殊的大质数取模时,可以考虑用原根\(g\)来代替,而这些特殊的大质数的原根恰好满足了某些性质,使得多项式乘法在模域中也可以快速的分治合并。

前置知识

  • \(a,p\)互质,且\(p>1\)
  • 对于\(a^n\equiv 1(mod\ p)\)最小的\(n\),我们成为\(a\)\(p\)的阶,记做\(\delta_p(a)\)
  • 例如:\(\delta_7(2)=3\)

原根

  • \(p\)是正整数,\(a\)是整数,若\(\delta_p(a)\)等于\(\varphi(p)\),则称\(a\)为模\(p\)的一个原根。
  • 比如说\(\delta_7(3)=6=\varphi(7)\),因此\(3\)是模\(7\)的一个原根。
  • 重要定理:(其实只要知道这个就行了)

  • 对于\(g,p\in Z\),如果\(g^i\ mod\ p(1\leq i\leq p-1)\)的值互不相同,则称\(g\)\(p\)的原根。
  • 常见的模数有\(998244353,1004535809,469762049\),这几个数的原根都是\(3(g=3)\)

NTT

  • \(FFT\)能够大大优化多项式乘法是因为单位复根有特殊且优秀的性质。
  • 原根也有。
  • \(NTT\)中,用原根来代替\(FFT\)中的单位复根。
  • 任意模数\(NTT\)以后再说。

洛谷3803:多项式乘法

代码和FFT挺像的。

#include
using namespace std;
typedef long long ll;
templateinline void read(T &x){
    x=0;
    static int p;p=1;
    static char c;c=getchar();
    while(!isdigit(c)){if(c=='-')p=-1;c=getchar();}
    while(isdigit(c)) {x=(x<<1)+(x<<3)+(c-48);c=getchar();}
    x*=p;
}

const int maxn = 5e6 + 10;
const int mod = 998244353;
int n, m, a[maxn], b[maxn], limit=1, bit;
int rev[maxn];

ll qmi(ll a, ll b)
{
    ll res = 1; res %= mod;
    while(b)
    {
        if(b&1) res = (res*a) % mod;
        b >>= 1;
        a = (a*a)%mod;
    } return res%mod;
}

void NTT(int c[], int op)
{
    for(int i = 0; i < limit; i++)
        if(i < rev[i]) swap(c[i], c[rev[i]]);
    for(int mid = 1; mid < limit; mid <<= 1)
    {
        ll gn = qmi(3, (mod-1)/(mid<<1));
        if(op == -1) gn = qmi(gn, mod-2);
        for(int j = 0, R = mid<<1; j < limit; j += R)
        {
            ll g = 1;
            for(int k = 0; k < mid; k++, g = (g*gn)%mod)
            {
                int x = c[j+k], y = g*c[j+k+mid]%mod;
                c[j+k] = (x+y)%mod;
                c[j+k+mid] = (x-y+mod)%mod;
            }
        }
    }
}

int main()
{
    read(n), read(m);
    for(int i = 0; i <= n; i++) read(a[i]);
    for(int i = 0; i <= m; i++) read(b[i]);

    limit = 1;
    while(limit <= n+m) limit <<= 1, bit++;
    for(int i = 0; i < limit; i++)
        rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit-1));

    NTT(a, 1); NTT(b, 1);
    for(int i = 0; i < limit; i++) a[i] = 1ll*a[i]*b[i]%mod;
    NTT(a, -1);
    ll inv = qmi(limit, mod-2);
    for(int i = 0; i <= n+m; i++)
        printf("%d ", (a[i]*inv)%mod);
    return 0;
}

你可能感兴趣的:(NTT学习笔记)