CF1740F Conditional Mix
有一个正整数 n n n和一个长度为 n n n的序列 a a a, 1 ≤ a i ≤ n 1\leq a_i\leq n 1≤ai≤n。
把每个 a i a_i ai看成一个一元集 { a i } \{a_i\} {ai},每次可以合并两个交集为空的集合。可以经过任意次合并。设合并完每个集合的元素个数组成可重集为 S S S,求 S S S的种类数,输出答案模 998244353 998244353 998244353。
1 ≤ n ≤ 2000 1\leq n\leq 2000 1≤n≤2000
首先我们通过补 0 0 0将这个可重集内元素个数变为 n n n。不妨设可重集内的元素是从大到小的。
对于可重集内两个位置 i , j i,j i,j,若其对应的集合的大小为 A i A_i Ai和 A j A_j Aj且 A i > A j A_i>A_j Ai>Aj,那么集合 A i A_i Ai中必定有一个元素可以取出并放到 A j A_j Aj。也就是说,对于两个可重集 A A A和 B B B,如果满足 ∀ 1 ≤ k ≤ n \forall1\leq k\leq n ∀1≤k≤n都满足 ∑ i = 1 k A i ≤ ∑ i = 1 k B i \sum\limits_{i=1}^kA_i\leq \sum\limits_{i=1}^kB_i i=1∑kAi≤i=1∑kBi,那么如果 A A A合法,则 B B B也合法。
那么,对于每一个 k k k,我们只需要找到 ∑ i = 1 k A i \sum\limits_{i=1}^{k}A_i i=1∑kAi的上界即可。如果一个数出现次数小于等于 k k k,那肯定都能选;如果出现次数大于 k k k,则最多只能取 k k k次。设 c n t i cnt_i cnti表示 i i i出现的次数,那么 ∑ i = 1 k A i ≤ ∑ i = 1 n min ( k , c n t i ) \sum\limits_{i=1}^kA_i\leq \sum\limits_{i=1}^n\min(k,cnt_i) i=1∑kAi≤i=1∑nmin(k,cnti)。接下来,问题就转化为满足条件的可重集 A A A的个数。
我们从大到小选数。设 f i , j , k f_{i,j,k} fi,j,k表示选了前 i i i个数,和为 j j j,可重集中最小的元素大于等于 k k k。那么转移如下
f i , j , k = f i , j , k + 1 + [ j ≥ k ] × f i , j − k , k f_{i,j,k}=f_{i,j,k+1}+[j \geq k]\times f_{i,j-k,k} fi,j,k=fi,j,k+1+[j≥k]×fi,j−k,k
转移为 O ( 1 ) O(1) O(1)的,但时间复杂度和空间复杂度为 O ( n 3 ) O(n^3) O(n3)。我们考虑优化。
对于空间,我们可以用滚动数组来省去一维。
设当前可重集中的元素分别为 A 1 , A 2 , … , A i A_1,A_2,\dots,A_i A1,A2,…,Ai,显然 A 1 , A 2 , … , A i ≥ k A_1,A_2,\dots,A_i\geq k A1,A2,…,Ai≥k,所以 ∑ j = 1 i A j ≥ i × k \sum\limits_{j=1}^iA_j\geq i\times k j=1∑iAj≥i×k。又因为 ∑ j = 1 i A j ≤ n \sum\limits_{j=1}^iA_j\leq n j=1∑iAj≤n,所以 i × k ≥ n i\times k\geq n i×k≥n,即 k ≤ n i k\leq \dfrac ni k≤in。也就是说,每次枚举 k k k的时间复杂度为 O ( n i ) O(\dfrac ni) O(in),那么总时间复杂度为 O ( n 2 ln n ) O(n^2\ln n) O(n2lnn)。
#include
using namespace std;
const int N=2000;
int n,x,cnt[2005],w[2005];
long long ans,f[2][2005][2005];
long long mod=998244353;
bool cmp(int ax,int bx){
return ax>bx;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&x);++cnt[x];
}
sort(cnt+1,cnt+n+1,cmp);
for(int k=1;k<=n;k++){
for(int i=1;i<=n;i++) w[k]+=min(k,cnt[i]);
}
for(int i=0;i<=n;i++) f[0][0][i]=1;
for(int i=1,e=1;i<=n;i++,e^=1){
for(int j=0;j<=w[i];j++){
f[e][j][n/i+1]=0;
}
for(int k=n/i;k>=0;k--){
for(int j=0;j<=w[i];j++){
f[e][j][k]=f[e][j][k+1];
if(j>=k) f[e][j][k]=(f[e][j][k]+f[e^1][j-k][k])%mod;
}
}
}
printf("%lld",f[n&1][n][0]);
return 0;
}