Luogu P3803 多项式乘法(ntt版)

文章目录

  • 思路
  • NTT
  • 原根
  • NTT
  • code
    • 版本1:FFT
    • 版本2:NTT
    • 版本3(函数版NTT,带求乘法逆)
  • Thanks!

题目链接

思路

相信大家都已经学会了 F F T FFT FFT,若不会,请看这篇博客
一个同学写的,自认为不错
在我们的 F F T FFT FFT中,我们使用了复数来进行计算
但是我们可以发现复数的乘法时间复杂度是 O ( 4 ) O(4) O(4)
d o u b l e double double的计算则更加增添了时间复杂度
同时因为浮点数计算     \sqrt{\ \ \ }     的精度会有误差
导致我们最终的所有部分的和反而与完整的 360 360 360°
那么我们不妨试想,若是在系数为整数的情况下
或者需要取模的时候,我们该如何来解决呢?

NTT

于是我们就可以介绍今天的主角了
NTT —— 快速数论变换
是一种建立在数论上的对FFT的优化
(或者可以说是取模运算的FFT)
只不过由于FFT用到是复数
而且double在做了大量的实数运算之后
精度损失较大
而我们的NTT就可以在模意义下
快速做这样的一个多项式乘法
NTT常数小一些
一般这个模数被认为是 x ∗ 2 k + 1 x * 2^k+1 x2k+1

原根

接下来我们介绍一个东西——原根
m m m是正整数, a a a是整数
a a a m m m的阶等于 φ ( m ) \varphi(m) φ(m)
则称 a a a为模 m m m的一个原根。
假设一个数 g g g P P P的原根
那么 g i   m o d   P g^i\ mod\ P gi mod P的结果两两不同
且有 1 < g < P 11<g<P 0 < i < P 00<i<P
归根到底就是 g P − 1 ≡ 1 ( m o d   P ) g^{P-1} \equiv 1 (mod\ P) gP11(mod P)
当且仅当指数为 P − 1 P-1 P1的时候成立( P P P是素数)。
简单来说, g i   m o d   p ≠ g j m o d   p g^i\ mod\ p \neq g^j mod\ p gi mod p=gjmod p ( p p p为素数)
其中 i ≠ j i \ne j i=j i , j i, j i,j介于 1 1 1 ( p − 1 ) (p-1) (p1)之间
g g g p p p的原根。
提供一种暴力的原根的求法

int calc(int x)// 求原根
{
    if (x == 2)
        return 1;
    for (int i = 2;i ;i++)
    {
        bool f = 1;
        for (int j = 2;j * j < x; j++)
            if (qkpow(i, (x - 1) / j, x) == 1)
            {
                f = 0;
                break;
            }
        if (f)
            return i;
    }
}

下面是一个常用模数的表
转载自Rayment_cc大佬的博客
Luogu P3803 多项式乘法(ntt版)_第1张图片

NTT

而有了原根以后
我们就可以将复数替换成原根的 m o d − 1 2   ∗   i \frac{mod - 1}{2\ *\ i} 2  imod1次方
然后将FFT套上去
就可以实现了
至于为什么,请看这里
Luogu P3803 多项式乘法(ntt版)_第2张图片
于是我们就搞定了原根与单位复根的转换

code

版本1:FFT

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define pk putchar(' ')
#define ph puts("")
#pragma GCC optimize(2)
using namespace std;
typedef long long ll;
typedef complex<double> cpd;
template <class T>
void rd(T &x)
{
    x = 0;
    int f = 1;
    char c = getchar();
    while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
    while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
    x *= f;
}
template <class T>
void pt(T x)
{
    if (x < 0)
        putchar('-'), x = (~x) + 1;
    if (x > 9)
        pt(x / 10);
    putchar(x % 10 ^ 48);
}
template <class T>
T Max(T a, T b)
{
    return a > b ? a : b;
}
template <class T>
T Min(T a, T b)
{
    return a < b ? a : b;
}
template <class T>
T Fabs(T x)
{
    return x > 0 ? x : -x;
}
template <class T>
void Swap(T &x, T &y)
{
    T t = x;
    x = y;
    y = t;
}
const int N = (1 << 18) + 5;
const double pi = acos(-1.0);
int n, R[N], l, lim, m;
cpd a[N], b[N];
void fft(cpd *a, int opt)
// opt = 1 => dft
// opt = -1 => idft
{
    for (int i = 0; i < lim; i++)
        if (i < R[i])
            Swap(a[i], a[R[i]]);
    for (int i = 1;i < lim; i <<= 1)
    {
        cpd x{cos(pi / i), opt * sin(pi / i)}, y{1, 0};
        for (int j = 0;j < lim; j += (i << 1), y = {1, 0})
            for (int k = 0;k < i; k++, y *= x)
            {
                cpd p = a[j + k], q = y * a[j + k + i];
                a[j + k] = p + q;
                a[j + k + i] = p - q;
            }
    }
}
// void ptcp(cpd x)
// {
//     printf ("%f+%fi\n", x.real(), x.imag());
//     return;
// }
int main()
{
    rd(n), rd(m);
    for (int i = 0, x;i <= n; i++)
        rd(x), a[i] = x;
    for (int i = 0, x;i <= m; i++)
        rd(x), b[i] = x;
    for (lim = 1;lim <= n + m; lim <<= 1) 
        l++;
    for (int i = 0;i < lim; i++)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (l - 1));
    fft(a, 1);
    fft(b, 1);
    for (int i = 0;i <= lim; i++)
        a[i] = a[i] * b[i];
    fft(a, -1);
    for (int i = 0;i <= n + m; i++)
        pt(int((a[i].real() + 0.5) / lim)), pk;
    return 0;
}

版本2:NTT

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define pk putchar(' ')
#define ph puts("")
#pragma GCC optimize(2)
using namespace std;
typedef long long ll;
template <class T>
void rd(T &x)
{
    x = 0;
    int f = 1;
    char c = getchar();
    while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
    while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
    x *= f;
}
template <class T>
void pt(T x)
{
    if (x < 0)
        putchar('-'), x = (~x) + 1;
    if (x > 9)
        pt(x / 10);
    putchar(x % 10 ^ 48);
}
template <class T>
T Max(T a, T b)
{
    return a > b ? a : b;
}
template <class T>
T Min(T a, T b)
{
    return a < b ? a : b;
}
template <class T>
T Fabs(T x)
{
    return x > 0 ? x : -x;
}
template <class T>
void Swap(T &x, T &y)
{
    T t = x;
    x = y;
    y = t;
}
const int N = (1 << 18) + 5, mod = 998244353;
int n, R[N], l, lim, m, a[N], b[N], g;
int qkpow(int x, int y, int mod)
{
    int res = 1;
    while (y)
    {
        if (y & 1)
            res = 1ll * res * x % mod;
        y >>= 1;
        x = 1ll * x * x % mod;
    }
    return res;
}
int calc(int x)// 求原根
{
    if (x == 2)
        return 1;
    for (int i = 2;i ;i++)
    {
        bool f = 1;
        for (int j = 2;j * j < x; j++)
            if (qkpow(i, (x - 1) / j, x) == 1)
            {
                f = 0;
                break;
            }
        if (f)
            return i;
    }
}
void ntt(int *a, int opt)
// opt = 1 => dft
// opt = -1 => idft
{
    for (int i = 0; i < lim; i++)
        if (i < R[i])
            Swap(a[i], a[R[i]]);
    for (int i = 1;i < lim; i <<= 1)
    {
        int tmp = qkpow(g, (mod - 1) / (i << 1), mod);
        if (opt == -1)
            tmp = qkpow(tmp, mod - 2, mod);
        for (int j = 0, y = 1;j < lim; j += (i << 1), y = 1)
            for (int k = 0;k < i; k++, y = 1ll * y * tmp % mod)
            {
                int p = a[j + k], q = 1ll * y * a[j + k + i] % mod;
                a[j + k] = (p + q) % mod;
                a[j + k + i] = (p - q + mod) % mod;
            }
    }
    if (opt == -1)
    {
        int inv = qkpow(lim, mod - 2, mod);
        for (int i = 0;i < lim; i++)
            a[i] = 1ll * inv * a[i] % mod;
    }
}
int main()
{
    rd(n), rd(m);
    for (int i = 0;i <= n; i++)
        rd(a[i]);
    for (int i = 0;i <= m; i++)
        rd(b[i]);
    for (lim = 1;lim <= n + m; lim <<= 1) 
        l++;
    for (int i = 0;i < lim; i++)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (l - 1));
    g = calc(mod);
    ntt(a, 1);
    ntt(b, 1);
    for (int i = 0;i <= lim; i++)
        a[i] = 1ll * a[i] * b[i] % mod;
    ntt(a, -1);
    for (int i = 0;i <= n + m; i++)
        pt(a[i]), pk;
    return 0;
}

版本3(函数版NTT,带求乘法逆)

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define pk putchar(' ')
#define ph puts("")
#pragma GCC optimize(2)
using namespace std;
typedef long long ll;
template <class T>
void rd(T &x)
{
    x = 0;
    ll B = 1;
    char c = getchar();
    while (!isdigit(c)) {if (c == '-') B = -1; c = getchar();}
    while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
    x *= B;
}
template <class T>
void pt(T x)
{
    if (x < 0)
        putchar('-'), x = (~x) + 1;
    if (x > 9)
        pt(x / 10);
    putchar(x % 10 ^ 48);
}
template <class T>
T Max(T a, T b)
{
    return a > b ? a : b;
}
template <class T>
T Min(T a, T b)
{
    return a < b ? a : b;
}
template <class T>
T Fabs(T x)
{
    return x > 0 ? x : -x;
}
template <class T>
void Swap(T &x, T &y)
{
    T t = x;
    x = y;
    y = t;
}
const int N = (1 << 19) + 5, mod = 998244353, p = 3; 
int n, R[N], c[N], a[N], b[N], m;
int qkpow(int x, int y, int mod)
{
    int res = 1;
    while (y)
    {
        if (y & 1)
            res = 1ll * res * x % mod;
        y >>= 1;
        x = 1ll * x * x % mod;
    }
    return res;
}
void ntt(int *a, int opt, int lim, int mod)
// opt = 1 => dft
// opt = -1 => idft
{
    int p_inv = opt == -1 ? qkpow(p, mod - 2, mod) : 0;
    for (int i = 0; i < lim; i++)
        if (i < R[i])
            Swap(a[i], a[R[i]]);
    for (int i = 1;i < lim; i <<= 1)
    {
        int tmp = qkpow(opt == 1 ? p : p_inv, (mod - 1) / (i << 1), mod);
        for (int j = 0, y = 1;j < lim; j += (i << 1), y = 1)
            for (int k = 0;k < i; k++, y = 1ll * y * tmp % mod)
            {
                int p = a[j + k], q = 1ll * y * a[j + k + i] % mod;
                a[j + k] = (p + q) % mod;
                a[j + k + i] = (p - q + mod) % mod;
            }
    }
    if (opt == -1)
    {
        int inv = qkpow(lim, mod - 2, mod);
        for (ll i = 0;i < lim; i++)
            a[i] = 1ll * inv * a[i] % mod;
    }
}
void mul(int *a, int *b, int n, int m)
{
    int lim, l;
    for (l = 0, lim = 1;lim < n + m; l++, lim <<= 1);
    for (int i = 0;i < lim; R[i] = (R[i >> 1] >> 1) | ((i & 1) << (l - 1)), i++);
    ntt(a, 1, lim, mod);
    ntt(b, 1, lim, mod);
    for (int i = 0;i < lim; i++)
        a[i] = 1ll * a[i] * b[i] % mod;
    ntt(a, -1, lim, mod);
}
void inv(int *a, int *b, int siz)
{
    if (siz == 1)
    {
        b[0] = qkpow(a[0], mod - 2, mod);
        return;
    }
    inv(a, b, (siz + 1) >> 1);
    int lim, l;
    for (l = 0, lim = 1;lim < siz + siz; l++, lim <<= 1);
    for (int i = 0;i < lim; R[i] = (R[i >> 1] >> 1) | ((i & 1) << (l - 1)), i++);
    for (int i = 0;i < siz; i++)
        c[i] = a[i];
    for (int i = siz;i < lim; i++)
        b[i] = 0;
    ntt(c, 1, lim, mod);
    ntt(b, 1, lim, mod);
    for (int i = 0;i < lim; i++)
        b[i] = (b[i] * 2 % mod - 1ll * b[i] * b[i] % mod * c[i] % mod + mod) % mod;
    ntt(b, -1, lim, mod);
    for (int i = siz;i < lim; i++)
        b[i] = 0;
}
int main()
{
    // freopen("testdata.in", "r", stdin);
    // freopen("test.out", "w", stdout);
    rd(n), rd(m);
    n++, m++;
    for (int i = 0;i < n; i++)
        rd(a[i]);
    for (int j = 0;j < m; j++)
        rd(b[j]);
    mul(a, b, n, m);
    for (int i = 0;i < n + m - 1; i++)
        pt(a[i]), pk;
    return 0;
}

不要吝啬您的小鼠标,右上角的小手手,请点击下

Thanks!

你可能感兴趣的:(c++)