[HAOI2018]染色 题解

传送门

题目大意:一个长度为 $n$ 的颜色序列,每个位置从 $m$ 种颜色中任选一种(即至多出现 $m$ 种颜色)。如果某个方案中恰好有 $k$ 种颜色各出现了恰好 $S$ 次,则该方案的愉悦度为 $w_k$($k=0,1,2,\cdots,m$)。求所有方案愉悦度之和(对 $1004535809$ 取模)。

首先,恰好出现 $S$ 次的颜色数量至多为 $L=\min\left(m,\left\lfloor\frac{n}{S}\right\rfloor\right)$ 种。

设 $f_i$ 表示钦定 $i$ 种颜色恰好出现 $S$ 次(其余颜色不作限制,即“至少 $i$ 种”)的方案数,
$g_i$ 表示恰好 $i$ 种颜色出现了恰好 $S$ 次的方案数。

不难发现 $f_i$ 与 $g_i$ 具有容斥的关系:一个恰好有 $j$ 种颜色出现 $S$ 次的方案在 $f_i$ 中被计算了 $C_j^i$ 次,即 $f_i=\sum\limits_{j=i}^{L}C_j^ig_j$;并且 $ans=\sum\limits_{i=0}^{L}g_iw_i$。

考虑 $f_i$:从 $m$ 种颜色中选出 $i$ 种($C_m^i$),从 $n$ 个位置中选出 $iS$ 个($C_n^{iS}$),强行让这 $i$ 种颜色每种都出现 $S$ 次,它们的排列数就是 $\frac{(iS)!}{(S!)^i}$;然后剩下 $n-iS$ 个位置的颜色从剩下的 $m-i$ 种中任选($(m-i)^{n-iS}$),所以有
$$f_i=C_m^iC_n^{iS}\frac{(iS)!}{(S!)^i}(m-i)^{n-iS}$$

对 $f_i$ 进行容斥(二项式反演)即可得到 $g_i$:

$$g_i=\sum\limits_{j=i}^{L}(-1)^{j-i}C_j^if_j$$

$$g_i=\sum\limits_{j=i}^{L}(-1)^{j-i}\frac{j!}{i!(j-i)!}f_j$$

$$g_ii!=\sum\limits_{j=i}^{L}\frac{(-1)^{j-i}}{(j-i)!}\,j!f_j$$

令 $a_i=\frac{(-1)^i}{i!}$,$b_i=i!f_i$,并且设 $c$ 为 $b$ 的翻转(即 $c_i=b_{L-i}$),那么

$$g_ii!=\sum\limits_{j=i}^{L}a_{j-i}b_j$$

$$=\sum\limits_{j=0}^{L-i}a_jb_{i+j}$$

$$=\sum\limits_{j=0}^{L-i}a_jc_{L-i-j}$$

即 $g_ii!$ 恰为 $a$ 与 $c$ 卷积的第 $L-i$ 项。

化为了卷积的形式,NTT即可解决。

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <utility>

// Reads one signed decimal integer from stdin into x.
// Non-digit characters before the number are skipped; if any of them is '-',
// the value is negated. If EOF is reached before a digit appears, x is left
// as 0 instead of spinning forever (isdigit(EOF) is false and getchar keeps
// returning EOF, so the unguarded skip loop never terminated).
template <typename T> inline void read(T& x) {
    int f = 0, c = getchar(); x = 0;
    while (c != EOF && !isdigit(c)) f |= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - 48, c = getchar();
    if (f) x = -x;
}
// Reads several integers in order: read(a, b, c) == read(a); read(b); read(c);
template <typename T, typename... Args>
inline void read(T& x, Args&... args) {
    read(x); read(args...); 
}
// Prints x to stdout in base 10, with a leading '-' for negative values.
// Digits are collected least-significant first and then emitted in reverse.
template <typename T> void write(T x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    char digits[48];
    int count = 0;
    do {
        digits[count++] = static_cast<char>(x % 10 + 48);
        x /= 10;
    } while (x > 0);
    while (count > 0) putchar(digits[--count]);
}
// Prints x followed by a newline.
template <typename T> inline void writeln(T x) {
    write(x);
    putchar('\n');
}
// Replaces x with y when y is strictly smaller; reports whether x changed.
template <typename T> inline bool chkmin(T& x, const T& y) {
    if (!(y < x)) return false;
    x = y;
    return true;
}
// Replaces x with y when y is strictly larger; reports whether x changed.
template <typename T> inline bool chkmax(T& x, const T& y) {
    if (!(x < y)) return false;
    x = y;
    return true;
}

typedef long long LL;

// Returns x^k mod P using binary exponentiation.
// x may be any LL value (negative or >= P). It is normalised into [0, P) first,
// so every product below is of two values < P and fits in 64 bits for P < 2^31.
// k must be non-negative. The default P is the NTT modulus used by this program.
inline LL qpow(LL x, LL k, LL P = 1004535809) {
    x %= P;
    if (x < 0) x += P;  // C++ '%' keeps the sign of the dividend
    LL s = 1 % P;       // keeps the k == 0 result reduced even when P == 1
    for (; k; x = x * x % P, k >>= 1)
        if (k & 1) s = s * x % P;
    return s;
}

// Static capacities.
// maxm bounds the colour count m (w[0..m]); maxm << 2 bounds the NTT length lim,
// which calcConvolution sets to the smallest power of two greater than 2 * low.
// maxn bounds up = max(n, m), the largest index of the factorial tables.
// NOTE(review): presumably chosen from the problem limits (m <= 1e5, n <= 1e7) — confirm.
const int maxm = 1e5 + 207;
const int maxn = 1e7 + 7;
// NTT modulus P = 479 * 2^21 + 1 with generator G = 3; Gi = G^(P-2) is G^{-1} mod P,
// the root used by the inverse transform.
const LL P = 1004535809, G = 3, Gi = qpow(G, P - 2);

// w[i]: the weight w_i, read for i = 0..m.
// low: L = min(m, n / S), the most colours that can appear exactly S times.
// up: max(m, n), the largest index the factorial tables must reach.
// n, m, S: the input parameters.
int w[maxm], low, up, n, m, S;
// NTT state: r is the bit-reversal permutation, lim the transform length, l = log2(lim).
int r[maxm << 2], lim, l;
// a holds a_i = (-1)^i / i!; c holds b_i = i! * f_i in reversed order.
// After calcConvolution, a[k] is the k-th coefficient of their product.
LL a[maxm << 2], c[maxm << 2];
// fac[i] = i! mod P and ifac[i] = (i!)^{-1} mod P for i = 0..up
// (two LL arrays of maxn entries, roughly 160 MB of static storage).
LL fac[maxn], ifac[maxn];

// In-place iterative number-theoretic transform of A[0..lim) modulo P.
// tp == 1 runs the forward transform with root G. Any other tp uses the inverse
// root Gi, and tp == -1 additionally scales the result by lim^{-1}.
// Relies on the global bit-reversal table r and length lim (set by calcConvolution).
inline void ntt(LL *A, int tp) {
    // Put the inputs in bit-reversed order so the butterflies can run in place.
    for (int idx = 0; idx < lim; ++idx) {
        const int rev = r[idx];
        if (idx < rev) std::swap(A[idx], A[rev]);
    }
    const LL root = (tp == 1) ? G : Gi;
    for (int half = 1; half < lim; half <<= 1) {
        const int len = half << 1;
        // Primitive len-th root of unity modulo P.
        const LL step = qpow(root, (P - 1) / len);
        for (int base = 0; base < lim; base += len) {
            LL rot = 1;
            for (int k = 0; k < half; ++k) {
                LL &lo = A[base + k];
                LL &hi = A[base + k + half];
                const LL u = lo;
                const LL v = rot * hi % P;
                lo = (u + v) % P;
                hi = (u - v + P) % P;
                rot = rot * step % P;
            }
        }
    }
    if (tp != -1) return;
    // Inverse transform: divide every coefficient by lim.
    const LL invLim = qpow(lim, P - 2);
    for (int idx = 0; idx < lim; ++idx)
        A[idx] = A[idx] * invLim % P;
}

// Fills fac[0..n] with i! mod P and ifac[0..n] with (i!)^{-1} mod P.
// Only n! is inverted explicitly (Fermat: a^{P-2}). The remaining inverses use
// (i!)^{-1} = ((i+1)!)^{-1} * (i+1), so the whole table costs O(n) multiplications.
// Requires 0 <= n < maxn.
inline void initFac(int n) {
    fac[0] = 1;
    for (int i = 1; i <= n; ++i)
        fac[i] = fac[i - 1] * i % P;
    ifac[n] = qpow(fac[n], P - 2);
    // Walk down to i = 0 inclusive; this also yields ifac[0] = ifac[1] * 1 = 1.
    // The previous condition `i` (i != 0) never terminated for n == 0, since the
    // loop started at -1 and indexed below the array.
    for (int i = n - 1; i >= 0; --i)
        ifac[i] = ifac[i + 1] * (i + 1) % P;
}

inline void initArray() {
    for (int i = 0, cur = 1; i <= low; ++i, cur *= -1)
        a[i] = (P + cur) * ifac[i] % P;
    for (int i = 0; i <= low; ++i)
        c[i] = fac[m] * ifac[m - i] % P * fac[n] % P * qpow(ifac[S], i) % P * ifac[n - i * S] % P * qpow(m - i, n - i * S) % P;
    std::reverse(c, c + low + 1);
}

// Multiplies the polynomials in a[0..low] and c[0..low] via NTT and leaves the
// product (degree 2 * low) in a[0..2*low].
// lim is chosen as the smallest power of two greater than 2 * low, so the cyclic
// convolution cannot wrap onto the coefficients we read back. l = log2(lim), and
// r is the matching bit-reversal permutation.
inline void calcConvolution() {
    lim = 1;
    l = 0;  // reset so a repeated call does not build r from a stale bit count
    while (lim <= (low << 1)) {
        lim <<= 1;
        ++l;
    }
    // r[0] is always 0. Starting the loop at 1 avoids evaluating `<< (l - 1)` with
    // l == 0 (a negative shift count, undefined behaviour) when low == 0 and lim == 1.
    r[0] = 0;
    for (int i = 1; i < lim; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    ntt(a, 1); ntt(c, 1);
    for (int i = 0; i < lim; ++i)
        a[i] = a[i] * c[i] % P;
    ntt(a, -1);
}

int main() {
    read(n, m, S);

    for (int i = 0; i <= m; ++i) read(w[i]);
    low = std::min(m, n / S);
    up = std::max(m, n);

    initFac(up);
    initArray();

    calcConvolution();

    LL ans = 0;
    for (int i = 0; i <= low; ++i)
        ans = (ans + ifac[i] * a[low - i] % P * w[i] % P) % P;
    writeln(ans);
    return 0;
}

你可能感兴趣的:(FFT)