2020-02-16模拟赛-2

前置知识(会的可以跳过)

1.第二类斯特林数

  S ( n , m ) \ S(n,m)  S(n,m)设为将   n \ n  n个不同的元素分为   m \ m  m个集合的方案数。也就时将集合   { x 1 , x 2 ⋯   , x n } \ \{x_{1},x_{2}\cdots,x_{n}\}  { x1,x2,xn}划分为   m \ m  m个非集合组成的集合   { A 1 , A 2 ⋯   , A m } \ \{A_{1},A_{2}\cdots,A_{m}\}  { A1,A2,Am}的方案数。(这里是集合的集合)

形象地说,将   n \ n  n个不同的小球放入   m \ m  m个相同的盒子的方案数。

第二类斯特林数有递推式   S ( n , m ) = S ( n − 1 , m − 1 ) + m S ( n − 1 , m ) \ S(n,m)=S(n-1,m-1)+mS(n-1,m)  S(n,m)=S(n1,m1)+mS(n1,m)

引理1

S ( n , m ) = 1 m ! ∑ i = 0 m ( − 1 ) i ( m i ) ( m − i ) n S(n,m)=\frac{1}{m !} \sum_{i=0}^{m} (-1)^{i}\binom{m}{i}(m-i)^{n} S(n,m)=m!1i=0m(1)i(im)(mi)n

证明
百度百科的优秀证明

这个引理使我们发现,第二类斯特林数的公式是卷积形式,所以我们可以利用   N T T \ NTT  NTT做到   O ( m log ⁡ m ) \ O(m \log{m})  O(mlogm)的复杂度内求出所有的   S ( n , i ) , i ∈ [ 0 , m ] \ S(n,i),i \in [0,m]  S(n,i),i[0,m]

引理2

x k = ∑ i = 0 k S ( k , i ) i ! ( x i ) x^{k}=\sum_{i=0}^{k}S(k,i)i!\binom{x}{i} xk=i=0kS(k,i)i!(ix)
证明:

左边是   k \ k  k个互不相同的球放在   x \ x  x个互不相同盒子里的方案数。

我们枚举非空盒子的数量   i \ i  i,非空盒子一共有   ( x i ) \ \binom{x}{i}  (ix)种取法,放置方法一共有   i ! S ( k , i ) \ i!S(k,i)  i!S(k,i)种。

2.分治   F F T \ FFT  FFT

我们有复杂度为   O ( n log ⁡ 2 ( n ) ) \ O(n\log^{2}(n))  O(nlog2(n))的算法将多个多项式乘在一起,其中这些多项式的最高次和为   n \ n  n

题解

引理3
假设我们的权值为   ∏ i = 1 M ( X i K i ) \ \prod_{i=1}^{M}\binom{X_{i}}{K_{i}}  i=1M(KiXi),则最终答案为   ( n + m − 1 m + ∑ i = 1 M K i − 1 ) \ \binom{n+m-1}{m+\sum_{i=1}^{M}K_{i}-1}  (m+i=1MKi1n+m1)

证明:

  ( X K ) \ \binom{X}{K}  (KX)相当于将   X + 1 \ X+1  X+1个球放到   K + 1 \ K+1  K+1个盒子中。
所以我们原本的意思是每个盒子中最开始有   1 \ 1  1个球,现在再将   N \ N  N个球放到   M \ M  M个盒子里,然后将第   i \ i  i个盒子里的球分为   K i \ K_{i}  Ki

显然这相当于把   N + M \ N+M  N+M个球分为   M + ∑ i = 1 M K i \ M+\sum_{i=1}^{M}K_{i}  M+i=1MKi个盒子里。(这里相当于将两次分组放到一起做)所以方案数为   ( n + m − 1 m + ∑ i = 1 M K i − 1 ) \ \binom{n+m-1}{m+\sum_{i=1}^{M}K_{i}-1}  (m+i=1MKi1n+m1)。而答案只与   ∑ i = 1 M K i \ \sum_{i=1}^{M}K_{i}  i=1MKi有关。

这样我们就得到了一个思路:

先利用引理2将   X i K i \ X_{i}^{K_{i}}  XiKi展开为组合数的和。显然转移为卷积形式。使用分治   F F T \ FFT  FFT求解即可。

代码

#include 
#define int long long
using namespace std;
const int mod = 998244353;
int fac[20002000], inv[20002000], kk[100100];
int aa[(1 << 20) + (2 << 2) + (2 >> 2) + (02 >> 1)];
int bb[(1 << 20) + (2 << 2) + (2 >> 2) + (02 >> 1)];
int t[(1 << 20) + (2 << 2) + (2 >> 2) + (02 >> 1)];
int nttl = 0;
int n, m;
struct nobe {
     
    int len;
    vector<int> pp;
    nobe(int l) : len(l) {
      pp.resize(l + 1); }
};
inline int poww(int a, int b, int mod) {
     
    int res = 1;
    while (b) {
     
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
inline void exgcd(int a, int b, int &x, int &y) {
     
    if (b == 0) {
     
        x = 1;
        y = 0;
        return;
    }
    exgcd(b, a % b, y, x);
    y -= x * (a / b);
}
inline int getinv(int a) {
     
    int x, y;
    exgcd(a, mod, x, y);
    return (x % mod + mod) % mod;
}
inline void init(int len) {
     
    int i = 0;
    while (i < nttl) {
     
        aa[i] = 0;
        bb[i] = 0;
        ++i;
    }
    nttl = 1;
    while (nttl <= len) nttl <<= 1;
}
inline void gett(int ty) {
     
    t[0] = 1;
    t[1] = poww(3, (mod - 1) / nttl, mod);
    if (ty)
        t[1] = getinv(t[1]);
    int i = 2;
    while (i < nttl) {
     
        t[i] = t[i - 1] * t[1] % mod;
        ++i;
    }
}
inline void ntt(int *pp, int ty) {
     
    int i = 0, j = 0, k = 0, l = 0, w = 0;
    while (i < nttl) {
     
        if (j > i)
            swap(pp[i], pp[j]);
        l = nttl >> 1;
        while ((j ^= l) < l) l >>= 1;
        ++i;
    }
    l = 2;
    while (l <= nttl) {
     
        j = l >> 1;
        int now = nttl / l;
        i = 0;
        while (i < nttl) {
     
            k = 0;
            int pw = 0;
            while (k < j) {
     
                int ww = pp[i + j + k] * t[pw] % mod, tt = pp[i + k];
                pp[i + j + k] = tt - ww + mod;
                pp[i + j + k] %= mod;
                pp[i + k] = tt + ww;
                pp[i + k] %= mod;
                pw += now;
                ++k;
            }
            i += l;
        }
        l <<= 1;
    }
    if (ty) {
     
        i = 0;
        int invl = getinv(nttl);
        while (i < nttl) {
     
            pp[i] = pp[i] % mod * invl % mod;
            ++i;
        }
    }
}
inline void mul() {
     
    gett(0);
    ntt(aa, 0);
    ntt(bb, 0);
    int i = 0;
    while (i < nttl) {
     
        aa[i] = aa[i] * bb[i] % mod;
        ++i;
    }
    gett(1);
    ntt(aa, 1);
}
inline nobe operator*(const nobe &a, const nobe &b) {
     
    nobe c(a.len + b.len);
    init(c.len);
    int i = 0;
    while (i <= a.len) {
     
        aa[i] = a.pp[i];
        ++i;
    }
    i = 0;
    while (i <= b.len) {
     
        bb[i] = b.pp[i];
        ++i;
    }
    mul();
    i = 0;
    while (i <= c.len) {
     
        c.pp[i] = aa[i];
        ++i;
    }
    return c;
}
inline nobe cdq(int l, int r) {
     
    int mid = l + ((r - l) >> 1);
    if (l == r) {
     
        nobe res(kk[l]);
        init(2 * kk[l]);
        int i = 0;
        while (i <= kk[l]) {
     
            aa[i] = inv[i] * poww(i, kk[l], mod) % mod;
            bb[i] = inv[i];
            if (i & 1)
                bb[i] = (-bb[i] + mod) % mod;
            ++i;
        }
        mul();
        i = 0;
        while (i <= kk[l]) {
     
            res.pp[i] = fac[i] * aa[i] % mod;
            ++i;
        }
        return res;
    }
    return cdq(l, mid) * cdq(mid + 1, r);
}
inline int getcc(int n, int m) {
     
    if (m > n)
        return 0;
    if (m < 0)
        return 0;
    return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
signed main() {
     
    freopen("b.in", "r", stdin);
    freopen("b.out", "w", stdout);
    scanf("%lld%lld", &m, &n);
    int i = 1;
    while (i <= m) {
     
        scanf("%lld", &kk[i]);
        ++i;
    }
    fac[0] = 1;
    i = 1;
    while (i <= 20000002) {
     
        fac[i] = fac[i - 1] * i % mod;
        ++i;
    }
    --i;
    inv[i] = getinv(fac[i]);
    --i;
    while (i >= 0) {
     
        inv[i] = inv[i + 1] * (i + 1) % mod;
        --i;
    }
    int ans = 0;
    nobe res = cdq(1, m);
    i = 0;
    while (i <= res.len) {
     
        ans += res.pp[i] * getcc(n + m - 1, m + i - 1);
        ans %= mod;
        ++i;
    }
    printf("%lld\n", ans);
    return 0;
}
/*
3 10
3 0 2
*/
/*
32967
*/

你可能感兴趣的:(数学,卷积,分治算法)