题目链接:https://ac.nowcoder.com/acm/contest/1080/F
题目描述
还记得校赛的"Protoss and Zerg"吗?(https://ac.nowcoder.com/acm/contest/303/H)
这是另一个版本。
---------------以下为原题目描述(稍有修改)----------------
1v1,是星际争霸(StarCraft)中最常见的竞技模式。
tokitsukaze进行了n场1v1。在每一场的1v1中,她都有星灵(Protoss)和异虫(Zerg)两个种族可以选择,分别有a个单位和b个单位。因为tokitsukaze不太擅长玩人类(Terran),所以她肯定不会选择人类。
对于每一场1v1,玩家只能控制己方单位。也就是说,如果选择虫族,那么只能控制虫族单位,如果玩家选择星灵,那么只能控制星灵单位。
在n场1v1中,假设第i场,有ai个星灵单位,和bi个虫族单位。tokitsukaze可以在一场1v1中,任选一种种族进行游戏。如果选择了星灵,那么在这场游戏中,可以选择出兵1到ai个单位。那么同理,如果选择了虫族,那么在这场游戏中,可以选择出兵1到bi个单位。
假设所有星灵单位互不相同,所有异虫单位也互不相同,那么请问tokitsukaze打完这n场1v1,出兵的总方案数是多少。
注意:若两个方案,有其中一个单位不同,即视为不相同。
---------------以上为原题目描述(稍有修改)----------------
但是,tokitsukaze认为这个问题太简单了。
tokitsukaze想知道,恰好选择了0,1,2,3,…,个星灵单位的方案数分别是多少。由于答案很大,所以输出答案mod 998244353 后的结果。
数据保证。
题解:(先考虑原题目怎么做:每一场是独立的,每一场出兵方案数为 2 a i + 2 b i − 2 2^{a_i} + 2 ^ {b_i} - 2 2ai+2bi−2,答案是每一场的乘积,是一个简单的组合计数问题)
每一场出 x x x ( x > 0 x > 0 x>0) 个星灵单位的方案数为: C a i x C_{a_i}^{x} Caix,出 0 个星灵单位的方案为 : 2 b i − 1 2 ^ {b_i} - 1 2bi−1。每一场出 i i i 个星灵单位的方案数可以构造一个多项式: ( 2 b i − 1 ∗ x 0 , C a i 1 ∗ x 1 , C a i 2 ∗ x 2 , . . . , C a i a i ∗ x a i ) (2^{b_i} - 1 * x_0,C_{a_i}^1 * x_1,C_{a_i}^2 * x_2,...,C_{a_i}^{a_i} * x_{a_i}) (2bi−1∗x0,Cai1∗x1,Cai2∗x2,...,Caiai∗xai)这个多项式的系数代表第 i 场,出 0 到 a i a_i ai个星灵单位的方案数。将它记为: ( a 0 ∗ x 0 , a 1 ∗ x 1 , a 2 ∗ x 2 , . . . , a n ∗ x n ) (a_0 * x_0,a_1 * x_1,a_2 * x_2,...,a_n * x_n) (a0∗x0,a1∗x1,a2∗x2,...,an∗xn)任意两场出的星灵单位之和的方案数: ( ∑ i = 0 0 a i ∗ b 0 − i , ∑ i = 0 1 a i ∗ b 1 − i , ∑ i = 0 2 a i ∗ b 2 − i , . . , ∑ i = 0 t o t a i ∗ b t o t − i ) (\sum_{i = 0}^{0} a_i * b_{0 - i},\sum_{i = 0}^{1} a_i * b_{1 - i},\sum_{i = 0}^{2} a_i * b_{2 - i},..,\sum_{i = 0}^{tot} a_i * b_{tot - i}) (i=0∑0ai∗b0−i,i=0∑1ai∗b1−i,i=0∑2ai∗b2−i,..,i=0∑totai∗btot−i)(tot是两场的星灵单位的总和)这是一个卷积式子,可以通过n次卷积得到最后的答案,由于 ∑ a i ≤ 2 ∗ 1 0 5 \sum a_i \leq 2*10^5 ∑ai≤2∗105,构造多项式不会超时,但由于两个式子做卷积,复杂度与两个式子长度之和相关,n次NTT可能会超时,考虑类似哈夫曼树的启发式合并:每次选 a i a_i ai最小的两场做卷积合并,复杂度为 O ( s u m a i ) ∗ l o g ( s u m a i ) ) ∗ l o g ( n ) O(sum_{a_i}) * log(sum_{a_i})) * log(n) O(sumai)∗log(sumai))∗log(n)
#include
using namespace std;
const int maxn = 2e5 + 10;
const int mod = 998244353;
typedef long long ll;
struct ss{
int a,b;
bool operator < (const ss & rhs) const {
return a < rhs.a;
}
}t[maxn];
struct tt{
int id,sz;
tt(int i = 0,int j = 0) {
id = i;sz = j;
}
bool operator < (const tt & rhs) const{
return rhs.sz < sz;
}
};
int n;
ll a[maxn * 10],b[maxn * 10];
ll pw[2 * maxn],fact[maxn * 2],ifact[maxn * 2];
vector<int> g[maxn];
priority_queue<tt> q;
ll fpow(ll a,ll b) {
ll r = 1;
while(b) {
if(b & 1) r = r * a % mod;
a = a * a % mod;
b >>= 1;
}
return r;
}
void change(ll t[],int len) {
for(int i = 1, j = len / 2; i < len - 1; i++) {
if(i < j) swap(t[i],t[j]);
int k = len / 2;
while(j >= k) {
j -= k;
k /= 2;
}
if(j < k) j += k;
}
}
void NTT(ll t[],int len,int type) {
change(t,len);
for(int s = 2; s <= len; s <<= 1) {
ll wn = fpow(3,(mod - 1) / s);
if(type == -1) wn = fpow(wn,mod - 2);
for(int j = 0; j < len; j += s) {
ll w = 1;
for(int k = 0; k < s / 2; k++) {
ll u = t[j + k],v = t[j + k + s / 2] * w % mod;
t[j + k] = (u + v) % mod;
t[j + k + s / 2] = (u - v + mod) % mod;
w = w * wn % mod;
}
}
}
if(type == -1) {
ll inv = fpow(len,mod - 2);
for(int i = 0; i < len; i++)
t[i] = t[i] * inv % mod;
}
}
ll C(int a,int b) {
if(a < b) return 0;
return (fact[a] * ifact[b] % mod) * ifact[a - b] % mod;
}
int main() {
scanf("%d",&n);
pw[0] = 1;
int sum = 0;
for(int i = 1; i <= maxn; i++)
pw[i] = pw[i - 1] * 2 % mod;
fact[0] = 1;
for(int i = 1; i <= maxn; i++)
fact[i] = fact[i - 1] * i % mod;
ifact[maxn] = fpow(fact[maxn],mod - 2);
for(int i = maxn - 1; i >= 0; i--)
ifact[i] = ifact[i + 1] * (i + 1) % mod;
for(int i = 1; i <= n; i++) {
scanf("%d",&t[i].a);
sum += t[i].a;
}
for(int i = 1; i <= n; i++)
scanf("%d",&t[i].b);
for(int i = 1; i <= n; i++) {
g[i].push_back((pw[t[i].b] + mod - 1) % mod);
for(int j = 1; j <= t[i].a; j++)
g[i].push_back(C(t[i].a,j));
}
for(int i = 1; i <= n; i++) q.push(tt(i,g[i].size()));
while(q.size() >= 2) {
tt t1 = q.top();q.pop();
tt t2 = q.top();q.pop();
int l = t1.sz + t2.sz - 2;
int len = 1;
for(int i = 0; i < g[t1.id].size(); i++)
a[i] = g[t1.id][i];
for(int i = 0; i < g[t2.id].size(); i++)
b[i] = g[t2.id][i];
while(len <= l) len <<= 1;
for(int i = g[t1.id].size(); i < len; i++) a[i] = 0;
for(int i = g[t2.id].size(); i < len; i++) b[i] = 0;
g[t1.id].clear();g[t2.id].clear();
NTT(a,len,1);NTT(b,len,1);
for(int i = 0; i < len; i++)
a[i] = a[i] * b[i] % mod;
NTT(a,len,-1);
for(int i = 0; i <= l; i++)
g[t1.id].push_back(a[i]);
t1.sz = g[t1.id].size();
q.push(t1);
}
tt t = q.top();
for(int i = 0; i < t.sz; i++) {
if(i) printf(" ");
printf("%lld",g[t.id][i]);
}
return 0;
}