洛谷 P4245 【模板】MTT(三模数NTT)

题目の传送门

https://www.luogu.org/problemnew/show/P4245

思路

模数任意的多项式乘法。本题有两种做法,一种是拆系数FFT,另一种就是我写的跑得炒鸡慢的三模数NTT

首先相乘后每一位可能达到 NP2 1023 那么大。我们找三个NTT模数使其乘积大于 1023 ,然后暴力合并是不行的,因为会爆long long,我们先暴力合并两个,然后用奇技淫巧合并第三个就行了。

具体的做法,我是参考了KsCla大佬的博客的,写的真的很好(特别是下面的链接),大家快去膜他,我就不赘述了。

蒟蒻我有一条方程组的解不会证明(虽然某大佬们都说这是显然的),希望有大佬能告诉我怎么证:

若所有 mi 互质,且满足:
n1+n20(modm3)
n1+n30(modm2)
n2+n30(modm1)

则有
n10(modm2m3)
n20(modm1m3)
n30(modm1m2)

关于CRT合并一个同余方程组,合并后的那条方程和原来的方程组不是等价的吗?目前我只能从在[0,M)内有唯一解,通解是X+k*M去理解。

回到这题,取模的时候一定要认真膜啊,不然变成负数就不知所措了。

代码

#include 
#define maxn 1000100

using namespace std;
typedef long long LL;

const int M1 = 998244353, M2 = 1004535809, M3 = 469762049;
int n, n1, n2, P;
int Rev[maxn];
LL a[3][maxn], b[3][maxn], ans[maxn];

void Init(int x){
    int L = 0;
    for(n = 1; n < x; n <<= 1, L++);  
    for(int i = 0; i < n; i++)  Rev[i] = (Rev[i>>1]>>1) | ((i&1)<<(L-1));
}

LL Mul(LL x, int y, LL MOD){
    LL res = 0;
    for(; y; x = (x << 1) % MOD, y >>= 1)
        if(y & 1)  res = (res + x) % MOD;
    return res;
}

LL Pow(LL x, int y, int MOD){
    LL res = 1;
    for(; y; x = x * x % MOD, y >>= 1)
        if(y & 1)  res = res * x % MOD;
    return res;
}

void NTT(LL *A, int DFT, int MOD){
    for(int i = 0; i < n; i++)  if(i < Rev[i])  swap(A[i], A[Rev[i]]);

    for(int s = 1; (1<int m = 1 << s;
        LL wn = Pow(3, (DFT == 1) ? (MOD-1)/m : (MOD-1)-(MOD-1)/m, MOD);
        for(int k = 0; k < n; k += m){
            LL w = 1;
            for(int j = 0; j < (m>>1); j++){
                LL u = A[k+j], t = w * A[k+j+(m>>1)] % MOD;
                A[k+j] = (u+t) % MOD;
                A[k+j+(m>>1)] = (u-t+MOD) % MOD;
                w = w * wn % MOD;
            }
        }
    }
    if(DFT == -1){
        LL Inv = Pow(n, MOD-2, MOD);
        for(int i = 0; i < n; i++)  A[i] = A[i] * Inv % MOD;
    }
}

void Work(int x, int MOD){
    NTT(a[x], 1, MOD);
    NTT(b[x], 1, MOD);

    for(int i = 0; i < n; i++)  a[x][i] = a[x][i] * b[x][i] % MOD;

    NTT(a[x], -1, MOD);
}

void CRT(){
    LL M = 1LL * M1 * M2;
    for(int i = 0; i < n; i++){
        LL temp = 0;
        temp = (temp + Mul(a[0][i] * M2 % M, Pow(M2, M1-2, M1), M)) % M;
        temp = (temp + Mul(a[1][i] * M1 % M, Pow(M1, M2-2, M2), M)) % M;
        a[1][i] = temp;
    }
    for(int i = 0; i < n; i++){
        LL temp = (a[2][i] - a[1][i] % M3 + M3) % M3 * Pow(M%M3, M3-2, M3) % M3;
        ans[i] = (M % P * temp % P + a[1][i] % P) % P;
    }
}

int main(){

    scanf("%d%d%d", &n1, &n2, &P);

    n1 ++;  n2 ++;

    LL x;
    for(int i = 0; i < n1; i++){  
        scanf("%lld", &x);
        for(int j = 0; j < 3; j++)  a[j][i] = x % P;
    }
    for(int i = 0; i < n2; i++){
        scanf("%lld", &x);
        for(int j = 0; j < 3; j++)  b[j][i] = x % P;
    }

    Init(n1 + n2);

    Work(0, M1);
    Work(1, M2);
    Work(2, M3);

    CRT();

    for(int i = 0; i < n1+n2-1; i++)  printf("%lld ", ans[i]);

    return 0;
}

你可能感兴趣的:(数论,&,数学,FFT,&,NTT,中国剩余定理)