https://www.lydsy.com/JudgeOnline/problem.php?id=5306
先回顾一下「$\{1,2,\dots,M\}$ 中恰好有 $K$ 个合法」的容斥求法:
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
// Loop shorthands. Every macro argument is parenthesized so that expression
// arguments (e.g. `For(i, 0, a & b)`) are treated as a single operand instead
// of re-associating with the surrounding comparison or assignment.
// For:  i = a, a + 1, ..., b (inclusive).
#define For(i, a, b) for ((i) = (a); (i) <= (b); (i)++)
// Step: i = a, a + x, a + 2x, ... while i <= b.
#define Step(i, a, b, x) for ((i) = (a); (i) <= (b); (i) += (x))
// Pow:  k = 1, 2, 4, 8, ... while k < n.
#define Pow(k, n) for ((k) = 1; (k) < (n); (k) <<= 1)
using namespace std;
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Any characters that are neither digits nor '-' are skipped first. The
// character right after the number is consumed. Returns 0 if end of input is
// reached before a number starts. No overflow checking is performed.
inline int read() {
    // Must be int, not char: getchar() returns EOF (a negative int), which a
    // plain char cannot represent reliably (char may be unsigned).
    int c = getchar();
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return 0;  // previously this spun forever at end of input
    bool negative = false;
    int res = 0;
    if (c == '-')
        negative = true;
    else
        res = c - '0';
    while ((c = getchar()) != EOF && c >= '0' && c <= '9')
        res = res * 10 + (c - '0');
    return negative ? -res : res;
}
// Capacities: N bounds the weight, polynomial and NTT buffers; M bounds the
// factorial tables, which are indexed up to max(n, m).
const int N = 300005;
const int M = 10000005;
// NTT-friendly prime modulus: ZZQ - 1 = 479 * 2^21.
const int ZZQ = 1004535809;

// Input values and the weights w[0..m].
int n, m, s;
int w[N];
// Polynomials used by main and their NTT product.
int f[N], g[N], h[N], res[N];
// fac[i] = i!; inv[] first holds i^{-1}, then is turned into (i!)^{-1} in main.
int fac[M], inv[M];
// spw[i] = (inv[s])^i, built in main.
int spw[N];
// NTT support: bit-reversal table and per-length root-of-unity scratch.
int rev[N], gp[N];
// ff = NTT length (starts at 1), tot = log2(ff), gg = ff^{-1} mod ZZQ.
int ff = 1, gg, tot;
// Accumulated sum, before the final fac[m] * fac[n] factor is applied.
int ans;
// Binary exponentiation: returns a^b mod ZZQ, with a^0 == 1 (including 0^0).
// b must be non-negative; every call site in this file guarantees that.
int qpow(int a, int b) {
    long long base = a;
    long long acc = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1) acc = acc * base % ZZQ;
        base = base * base % ZZQ;
    }
    return static_cast<int>(acc);
}
void FFT(int n, int *a, int op) {
int i, j, k, sp = n;
gp[n] = qpow(op == 1 ? 3 : 334845270, (ZZQ - 1) / n);
For (i, 0, n - 1) if (i < rev[i]) swap(a[i], a[rev[i]]);
For (i, 1, tot) sp >>= 1,
gp[sp] = 1ll * gp[sp << 1] * gp[sp << 1] % ZZQ;
Pow(k, n) {
int x = gp[k << 1];
Step (i, 0, n - 1, k << 1) {
int w = 1;
For (j, 0, k - 1) {
int u = a[i + j], v = 1ll * w * a[i + j + k] % ZZQ;
a[i + j] = (u + v) % ZZQ;
a[i + j + k] = (u - v + ZZQ) % ZZQ;
w = 1ll * w * x % ZZQ;
}
}
}
}
int main() {
int i; fac[0] = inv[0] = inv[1] = spw[0] = 1;
n = read(); m = read(); s = read();
For (i, 0, m) w[i] = read();
For (i, 1, max(n, m)) fac[i] = 1ll * fac[i - 1] * i % ZZQ;
For (i, 2, max(n, m))
inv[i] = 1ll * (ZZQ - ZZQ / i) * inv[ZZQ % i] % ZZQ;
For (i, 2, max(n, m)) inv[i] = 1ll * inv[i] * inv[i - 1] % ZZQ;
For (i, 1, m) spw[i] = 1ll * spw[i - 1] * inv[s] % ZZQ;
For (i, 0, m) {
f[i] = i & 1 ? ZZQ - 1 : 1;
f[i] = 1ll * f[i] * inv[i] % ZZQ * spw[i] % ZZQ;
}
For (i, 0, m) {
if (n + (i - m) * s < 0) continue;
g[i] = qpow(i, n + (i - m) * s);
g[i] = 1ll * g[i] * inv[i] % ZZQ * inv[n + (i - m) * s] % ZZQ;
}
For (i, 0, m) h[i] = 1ll * w[i] * inv[i] % ZZQ * spw[i] % ZZQ;
while (ff <= (m << 1)) ff <<= 1, tot++;
gg = qpow(ff, ZZQ - 2);
For (i, 0, ff - 1)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << tot - 1);
FFT(ff, f, 1); FFT(ff, g, 1);
For (i, 0, ff - 1) res[i] = 1ll * f[i] * g[i] % ZZQ;
FFT(ff, res, -1);
For (i, 0, ff - 1) res[i] = 1ll * res[i] * gg % ZZQ;
For (i, 0, m) ans = (ans + 1ll * h[i] * res[m - i] % ZZQ) % ZZQ;
cout << 1ll * fac[m] * fac[n] % ZZQ * ans % ZZQ << endl;
return 0;
}