// Once the formula is derived, the rest is a straightforward NTT computation.
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

// Prime modulus for all arithmetic; 998244353 - 1 = 119 * 2^23, so NTT
// lengths up to 2^23 are available.
const int mod = 998244353;
// Capacity of the NTT scratch buffers a / b / c (xp shares this size).
const int maxn = 700005;
// init() precomputes dp, dp2, f, g and xp for indices 0..lim.
const int lim = 100000;

LL dp[lim + 5], dp2[lim + 5];
LL a[maxn], b[maxn], c[maxn], xp[maxn];
LL f[lim + 5], g[lim + 5];

// Computes a^b mod p by binary exponentiation (b is expected to be >= 0).
// Every reduction uses the `p` argument; the previous version reduced by the
// global `mod` and silently ignored `p`, which broke NTT::setMod(_p) for any
// prime other than 998244353.
LL powmod(LL a, LL b, LL p) {
    LL res = 1 % p;
    LL base = a % p;
    if (base < 0) base += p;
    while (b) {
        if (b % 2) res = res * base % p;
        base = base * base % p;
        b /= 2;
    }
    return res;
}

// Number-theoretic transform modulo p (p is set via setMod()).
namespace NTT {
// r: primitive root of p (3 is assumed to be a primitive root of the default
// prime 998244353; this is not checked). gl: number of root powers that
// setMod() precomputes.
const int r = 3, gl = 25;
// p: working prime; rp[s] / irp[s]: roots applied at butterfly stage s of the
// forward / inverse transform.
LL p, rp[50], irp[50];

// Sets the working prime and fills rp[i] = r^((p-1)/2^i) for i < gl.
// For p = 998244353, 2^23 is the largest power of two dividing p - 1, so
// transform lengths must not exceed 2^23.
void setMod(LL _p = 998244353) {
    p = _p;
    for (int i = 0; i < gl; i++) rp[i] = powmod(r, (p - 1) / (1 << i), p);
}

// In-place iterative radix-2 transform of a[0..n-1]. n must be a power of two
// and every entry must already lie in [0, p). wt[s] is the root applied at the
// stage whose butterflies span 2^s elements.
void FFT(LL a[], int n, LL wt[] = rp) {
    // Reorder a into bit-reversed index order.
    for (int i = 0, j = 0; i < n; i++) {
        if (j > i) swap(a[i], a[j]);
        int k = n;
        while (j & (k >>= 1)) j &= ~k;
        j |= k;
    }
    for (int m = 1, s = 1; m < n; m <<= 1, s++) {
        LL w = 1;
        for (int k = 0; k < m; ++k) {
            for (int i = k; i < n; i += m << 1) {
                LL v = a[i + m] * w % p;
                a[i + m] = a[i] - v;
                if (a[i + m] < 0) a[i + m] += p;
                a[i] += v;
                if (a[i] >= p) a[i] -= p;
            }
            w = w * wt[s] % p;
        }
    }
}

// Inverse transform. irp[s] = rp[s]^(n-1) equals rp[s]^(-1) for every stage s
// actually used, because rp[s]^(2^s) == 1 and 2^s divides n. The result is
// scaled by n^(-1).
void IFFT(LL a[], int n) {
    for (int i = 0; i < gl; i++) irp[i] = powmod(rp[i], n - 1, p);
    FFT(a, n, irp);
    LL inv = powmod(n, p - 2, p);
    for (int i = 0; i < n; i++) a[i] = a[i] * inv % p;
}

// Cyclic convolution of length n: c = a * b (mod p). Both a and b are
// overwritten with their forward transforms.
void Mul(LL a[], LL b[], LL n, LL c[]) {
    FFT(a, n);
    FFT(b, n);
    for (int i = 0; i < n; i++) c[i] = a[i] * b[i] % p;
    IFFT(c, n);
}
}  // namespace NTT

// Smallest power of two strictly greater than 2 * len. It is used as the NTT
// length so that the linear convolutions below never wrap around.
int ntt_len(int len) {
    int n = 1;
    while (2 * len >= n) n *= 2;
    return n;
}

// Convolution operands for the dp recurrence (see cdq).
LL f1(int x) { return dp[x] * g[x] % mod; }
LL f2(int x) { return xp[x] * g[x - 1] % mod; }

// CDQ divide and conquer over dp[L..R]. When called as cdq(0, lim) with
// dp[0] = 1 and every other dp entry 0, it fills, for i >= 1,
//   dp[i] = sum_{j=0}^{i-1} C(i-1, j) * dp[j] * xp[i-j]   (mod mod).
// Before recursing into the right half, the left half's contribution is added
// to it with one NTT. Until its leaf is reached, dp[i] holds that sum divided
// by (i-1)!; the leaf multiplies it back by f[i-1].
void cdq(int L, int R) {
    if (L == R) {
        if (L) dp[L] = dp[L] * f[L - 1] % mod;
        return;
    }
    int mid = (L + R) >> 1;
    cdq(L, mid);
    int N = R - L + 1, n = ntt_len(N);
    fill(a, a + n + 1, 0LL);
    fill(b, b + n + 1, 0LL);
    fill(c, c + n + 1, 0LL);
    for (int i = 0; i <= mid - L; i++) a[i] = f1(i + L);
    // Only c[0..N-1] is consumed, so offsets up to N - 1 suffice. Filling b[N]
    // would read xp[lim + 1] on the top-level call, which init() never fills.
    for (int i = 1; i < N; i++) b[i] = f2(i);
    NTT::Mul(a, b, n, c);
    for (int i = mid + 1; i <= R; i++) dp[i] = (dp[i] + c[i - L]) % mod;
    cdq(mid + 1, R);
}

// Precomputes every table for indices 0..lim:
//   f[i] = i!,  g[i] = (i!)^(-1),  xp[i] = i^(i-2) (xp[1] = xp[2] = 1),
//   dp via cdq(), and
//   dp2[i] = sum_{j=0}^{i-2} C(i-2, j) * dp[j] * xp[i-j] for i >= 2,
//   with dp2[0] = dp2[1] = 1.
void init() {
    const int N = lim;
    memset(dp, 0, sizeof dp);
    memset(dp2, 0, sizeof dp2);
    f[0] = 1;
    for (int i = 1; i <= N; i++) f[i] = f[i - 1] * i % mod;
    g[N] = powmod(f[N], mod - 2, mod);
    for (int i = N - 1; i >= 0; i--) g[i] = g[i + 1] * (i + 1) % mod;
    xp[1] = xp[2] = 1;
    for (int i = 3; i <= N; i++) xp[i] = powmod(i, i - 2, mod);
    dp[0] = 1;
    NTT::setMod();
    cdq(0, N);

    // One full-size convolution yields dp2 from the finished dp table.
    int n = ntt_len(N);
    memset(a, 0, sizeof a);
    memset(b, 0, sizeof b);
    memset(c, 0, sizeof c);
    for (int i = 0; i <= N; i++) a[i] = dp[i] * g[i] % mod;
    for (int i = 2; i <= N; i++) b[i] = xp[i] * g[i - 2] % mod;
    NTT::Mul(a, b, n, c);
    dp2[0] = dp2[1] = 1;
    for (int i = 2; i <= N; i++) dp2[i] = c[i] * f[i - 2] % mod;
}

// Answers one test case. Reads m followed by m sizes x and prints
//   (2^m * prod dp[x] - prod dp2[x]) mod mod.
// NOTE(review): each x is assumed to lie in [0, lim]; this is not validated.
// If the input ends early, nothing is printed.
void work() {
    int m;
    if (scanf("%d", &m) != 1) return;
    LL ans = 1, res = 1;
    for (int i = 1; i <= m; i++) {
        int x;
        if (scanf("%d", &x) != 1) return;
        ans = ans * dp[x] % mod;
        res = res * dp2[x] % mod;
    }
    ans = ans * powmod(2, m, mod) % mod;
    ans = ((ans - res) % mod + mod) % mod;
    printf("%lld\n", ans);
}

// Builds the tables once, then answers T test cases.
int main() {
    init();
    int T;
    if (scanf("%d", &T) != 1) return 0;
    while (T-- > 0) work();
    return 0;
}