BZOJ 3992: [SDOI2015]序列统计

可以列出 dp 方程
\(dp[l][i] = \sum\limits_{j*k\equiv i \pmod M}dp[l-1][j]*c[k]\)
乘积的形式不好卷,考虑离散对数
\(M\) 是质数,存在原根 \(g\),那么 \(g^0,g^1,\dots, g^{M-2}\) 在模 \(M\) 意义下两两不同,恰好取遍 \(1\) 到 \(M-1\) 这 \(M-1\) 个数,于是 \(1\) 到 \(M-1\) 中的每个数都可以用它的指数(离散对数)来表示。
\(j*k\equiv i \pmod M\),就变成了 \(g^{j'}*g^{k'}\equiv g^{i'} \pmod M\),也就是 \(j'+k'\equiv i' \pmod {M-1}\)
这样就可以卷积了
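\(A(x) = C(x)^{N}\),其中 \(N\) 是序列长度,\(C(x)=\sum_{t} c'[t]\,x^{t}\),\(c'[t]=1\) 当且仅当 \(g^{t}\) 在集合中(这里的指数 \(N\) 与上面 dp 式中的下标 \(k\) 无关)。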
用快速幂加 NTT 就可以了,注意每做完一次卷积,要把 \(M-1\) ~ \(2M - 3\) 部分加到 \(0\) ~ \(M - 2\) 上(指数对 \(M-1\) 取模)。

#include <cstdio>
#include <utility>

// Modulus for all result arithmetic: 1004535809 = 479 * 2^21 + 1.
const int MOD = 1004535809;
// Capacity used for every fixed-size array in this file.
const int N = 100007;

// Normalises x in place after a single addition or subtraction:
// subtracts MOD once when x >= MOD, adds MOD once when x < 0.
// Values further out of range are only moved one step closer.
inline void M(int &x) {
    x = (x >= MOD) ? x - MOD
      : (x < 0)    ? x + MOD
                   : x;
}

// Binary exponentiation: returns a^b modulo `mod`.
// With the default arguments this is a^(MOD-2) mod MOD, which the callers in
// this file use as a modular inverse.
int qp(int a, int b = MOD - 2, int mod = MOD) {
    long long result = 1;
    long long base = a;
    while (b) {
        if (b & 1) {
            result = result * base % mod;
        }
        base = base * base % mod;
        b >>= 1;
    }
    // The final reduction matters when the loop never ran (b == 0, mod == 1).
    return static_cast<int>(result % mod);
}

// Product of a and b reduced modulo MOD; the multiplication is done in
// 64 bits so it cannot overflow.
int mul(int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % MOD);
}

namespace FFT {
    const int G = 3;
    int n, l, r[N];
    void init(int m) {
        n = 1, l = 0;
        while (n <= 2 * m) n <<= 1, l++;
        for (int i = 0; i < n; i++)
            r[i] = r[i >> 1] >> 1 | ((i & 1) << (l - 1));
    }
    void NTT(int a[], int pd) {
        for (int i = 0; i < n; i++)
            if (i < r[i])
                std::swap(a[i], a[r[i]]);
        for (int mid = 1; mid < n; mid <<= 1) {
            int l = mid << 1;
            int wn = qp(G, (MOD - 1) / l);
            if (pd == -1)
                wn = qp(wn, MOD - 2);
            for (int j = 0; j < n; j += l) {
                int w = 1;
                for (int k = 0; k < mid; k++, w = 1LL * w * wn % MOD) {
                    int u = a[j + k], v = mul(w, a[k + j + mid]);
                    a[k + j] = (u + v) % MOD;
                    a[k + j + mid] = (u - v + MOD) % MOD;
                }
            }
        }
        if (pd == -1) {
            int inv = qp(n, MOD - 2);
            for (int i = 0; i < n; i++)
                a[i] = 1LL * a[i] * inv % MOD;
        }
    }
    void mul(int *a, int *b, int m) {
        static int c[N], d[N];
        for (int i = 0; i < n; i++)
            c[i] = i < m ? a[i] % MOD : 0, d[i] = i < m ? b[i] % MOD : 0;
        NTT(c, 1); NTT(d, 1);
        for (int i = 0; i < n; i++)
            c[i] = 1LL * c[i] * d[i] % MOD;
        NTT(c, -1);
        for (int i = 0; i < n; i++) {
            if (i < m) M(a[i] = c[i] + c[i + m]);
            else a[i] = 0;
        }
    }
}

int cal(int n) {
    if (n == 2) return 1;
    static int p[N];
    int cnt = 0, x = n - 1;
    for (int i = 2; i * i <= x; i++) {
        if (x % i == 0) {
            p[++cnt] = i;
            while (x % i == 0)
                x /= i;
        }
    }
    if (x != 1) p[++cnt] = x;
    for (int i = 2; i < n; i++) {
        bool flag = 1;
        for (int j = 1; j <= cnt; j++)
            if (qp(i, (n - 1) / p[j], n) == 1) {
                flag = 0;
                break;
            }
        if (flag) return i;
    }
    return -1;
}

// Computes b <- a^k by binary exponentiation, using FFT::mul with length n
// for every product (so exponents wrap modulo n).
//   a : base polynomial, n coefficients; overwritten by its repeated squares.
//   b : output, n coefficients; fully initialised here to the constant 1.
//   k : exponent; k <= 0 leaves b as the constant polynomial 1.
// FFT::init must already have been called with a value >= n.
void poly_pow(int *a, int *b, int n, int k) {
    // Start from the multiplicative identity instead of relying on the
    // caller to hand in a zeroed buffer.
    for (int i = 1; i < n; i++) b[i] = 0;
    b[0] = 1;
    // k > 0 (rather than k != 0) so a negative exponent cannot loop forever.
    while (k > 0) {
        if (k & 1) FFT::mul(b, a, n);
        FFT::mul(a, a, n);
        k >>= 1;
    }
}

// Program-wide state filled in by main().
bool vis[N];  // vis[v] is set when v was read as an element of the input set
int n;        // exponent handed to poly_pow
int X;        // target value whose discrete logarithm selects the answer
int S;        // number of set elements to read
int a[N];     // a[i] = 1 when g^i (mod the input modulus) is in the set
int b[N];     // result polynomial written by poly_pow

int main() {
    int M;
    scanf("%d%d%d%d", &n, &M, &X, &S);
    int res = 0;
    for (int i = 1; i <= S; i++) {
        int x;
        scanf("%d", &x);
        vis[x] = 1;
    }
    int g = cal(M);
    if (g == -1) {
        puts("0");
        return 0;
    }
    for (int p = 1, i = 0; i < M - 1; i++, p = 1LL * g * p % M) {
        if (vis[p]) a[i] = 1;
        if (p == X) res = i;
    }
    FFT::init(M - 1);
    poly_pow(a, b, M - 1, n);
    printf("%d\n", b[res]);
    return 0;
}

你可能感兴趣的:(BZOJ 3992: [SDOI2015]序列统计)