任意模数FFT模板(一大一小模数NTT)

练手题:
51nod 1172 Partial Sums V2

这个一大一小模数NTT可以比三模数NTT少三次DFT,但是有三次DFT常数会大,因为用了大数乘法黑科技,整体还是要快的。

两个模数分别是:
998244353和29*2^57+1,原根都是3

用中国剩余定理合并的时候,只能两两合并,类似于扩展CRT,推推式子就行了。

Code:

#include
#include
#define ll long long
#define ld long double
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define ff(i, x, y) for(int i = x; i < y; i ++)
using namespace std;

const int mo = 1e9 + 7, mo1 = 998244353;


const int N = 5e4 + 5;

int tp; ll w[N * 4];

ll a[N * 4], b[N * 4];

ll ksm(ll x, ll y, const ll mo) {
    ll s = 1;
    for(; y; y /= 2, x = x * x % mo)
        if(y & 1) s = s * x % mo;
    return s;
}
void dft1(ll *a, int n, const ll mo) {
    ff(i, 0, n) {
        int p = 0, q = i;
        fo(j, 1, tp) p = p * 2 + (q & 1), q >>= 1;
        if(i > p) swap(a[i], a[p]);
    } ll A;
    for(int h = 1, g = 0, m; (m = 1 << ++ g) <= n; h = m)
        for(int j = 0; j < n; j += m) ff(i, 0, h)
            A = a[i + j + h] * w[i * (n >> g)], a[i + j + h] = (a[i + j] - A) % mo, a[i + j] = (a[i + j] + A) % mo;
}
void fft1(ll *a, ll *b, int n, const ll mo) {
    w[0] = 1; ll v = ksm(3, (mo - 1) / n, mo);
    fo(i, 1, n) w[i] = w[i - 1] * v % mo;
    dft1(a, n, mo); dft1(b, n, mo);
    ff(i, 0, n) a[i] = a[i] * b[i] % mo;
    fo(i, 1, n / 2) swap(w[i], w[n - i]);
    dft1(a, n, mo); v = ksm(n, mo - 2, mo);
    ff(i, 0, n) a[i] = (a[i] + mo) * v % mo;
}

const ll mo2 = (1LL << 57) * 29 + 1;
ll c[N * 4], d[N * 4];
ll mul(ll x, ll y) {
    ll z = (ld) x * y / mo2 ; z = x * y - z * mo2;
    if(z < 0) z += mo2; else if(z > mo2) z -= mo2;
    return z;
}
ll ksm2(ll x, ll y, const ll mo) {
    ll s = 1;
    for(; y; y /= 2, x = mul(x, x))
        if(y & 1) s = mul(s, x);
    return s;
}
void dft2(ll *a, int n, const ll mo) {
    ff(i, 0, n) {
        int p = 0, q = i;
        fo(j, 1, tp) p = p * 2 + (q & 1), q >>= 1;
        if(i > p) swap(a[i], a[p]);
    } ll A;
    for(int h = 1, g = 0, m; (m = 1 << ++ g) <= n; h = m)
        for(int j = 0; j < n; j += m) ff(i, 0, h)
            A = mul(a[i + j + h], w[i * (n >> g)]), a[i + j + h] = (a[i + j] - A + mo) % mo, a[i + j] = (a[i + j] + A) % mo;
}
void fft2(ll *a, ll *b, int n, const ll mo) {
    w[0] = 1; ll v = ksm2(3, (mo - 1) / n, mo);
    fo(i, 1, n) w[i] = mul(w[i - 1], v);
    dft2(a, n, mo); dft2(b, n, mo);
    ff(i, 0, n) a[i] = mul(a[i], b[i]);
    fo(i, 1, n / 2) swap(w[i], w[n - i]);
    dft2(a, n, mo); v = ksm2(n, mo - 2, mo);
    ff(i, 0, n) a[i] = mul(a[i], v);
}

int n, k, m;

int main() {
    scanf("%d %d", &n, &k);
    fo(i, 1, n) scanf("%d", &a[i]);
    b[0] = 1;
    fo(i, 1, n - 1) b[i] = b[i - 1] * (i + k - 1) % mo * ksm(i, mo - 2, mo) % mo;
    while(1 << tp ++ <= n);
    fo(i, 0, n - 1) d[i] = b[i];
    fo(i, 1, n) c[i] = a[i];
    fft1(a, b, 1 << tp, mo1); fft2(c, d, 1 << tp, mo2);
    fo(i, 1, n) {
        ll j = ((a[i] - c[i]) % mo1 + mo1) % mo1;
        printf("%lld\n", (j * ksm(mo2 % mo1, mo1 - 2, mo1) % mo1 * (mo2 % mo) + c[i]) % mo);
    }
}

你可能感兴趣的:(模版,FFT,NTT,FWT……)