CF1740F Conditional Mix

CF1740F Conditional Mix

题目大意

有一个正整数 n n n和一个长度为 n n n的序列 a a a 1 ≤ a i ≤ n 1\leq a_i\leq n 1ain

把每个 a i a_i ai看成一个一元集 { a i } \{a_i\} {ai},每次可以合并两个交集为空的集合。可以经过任意次合并。设合并完每个集合的元素个数组成可重集为 S S S,求 S S S的种类数,输出答案模 998244353 998244353 998244353

1 ≤ n ≤ 2000 1\leq n\leq 2000 1n2000


题解

首先我们通过补 0 0 0将这个可重集内元素个数变为 n n n。不妨设可重集内的元素是从大到小的。

对于可重集内两个位置 i , j i,j i,j,若其对应的集合的大小为 A i A_i Ai A j A_j Aj A i > A j A_i>A_j Ai>Aj,那么集合 A i A_i Ai中必定有一个元素可以取出并放到 A j A_j Aj。也就是说,对于两个可重集 A A A B B B,如果满足 ∀ 1 ≤ k ≤ n \forall1\leq k\leq n ∀1kn都满足 ∑ i = 1 k A i ≤ ∑ i = 1 k B i \sum\limits_{i=1}^kA_i\leq \sum\limits_{i=1}^kB_i i=1kAii=1kBi,那么如果 A A A合法,则 B B B也合法。

那么,对于每一个 k k k,我们只需要找到 ∑ i = 1 k A i \sum\limits_{i=1}^{k}A_i i=1kAi的上界即可。如果一个数出现次数小于等于 k k k,那肯定都能选;如果出现次数大于 k k k,则最多只能取 k k k次。设 c n t i cnt_i cnti表示 i i i出现的次数,那么 ∑ i = 1 k A i ≤ ∑ i = 1 n min ⁡ ( k , c n t i ) \sum\limits_{i=1}^kA_i\leq \sum\limits_{i=1}^n\min(k,cnt_i) i=1kAii=1nmin(k,cnti)。接下来,问题就转化为满足条件的可重集 A A A的个数。

我们从大到小选数。设 f i , j , k f_{i,j,k} fi,j,k表示选了前 i i i个数,和为 j j j,可重集中最小的元素大于等于 k k k。那么转移如下

f i , j , k = f i , j , k + 1 + [ j ≥ k ] × f i , j − k , k f_{i,j,k}=f_{i,j,k+1}+[j \geq k]\times f_{i,j-k,k} fi,j,k=fi,j,k+1+[jk]×fi,jk,k

转移为 O ( 1 ) O(1) O(1)的,但时间复杂度和空间复杂度为 O ( n 3 ) O(n^3) O(n3)。我们考虑优化。

对于空间,我们可以用滚动数组来省去一维。

设当前可重集中的元素分别为 A 1 , A 2 , … , A i A_1,A_2,\dots,A_i A1,A2,,Ai,显然 A 1 , A 2 , … , A i ≥ k A_1,A_2,\dots,A_i\geq k A1,A2,,Aik,所以 ∑ j = 1 i A j ≥ i × k \sum\limits_{j=1}^iA_j\geq i\times k j=1iAji×k。又因为 ∑ j = 1 i A j ≤ n \sum\limits_{j=1}^iA_j\leq n j=1iAjn,所以 i × k ≥ n i\times k\geq n i×kn,即 k ≤ n i k\leq \dfrac ni kin。也就是说,每次枚举 k k k的时间复杂度为 O ( n i ) O(\dfrac ni) O(in),那么总时间复杂度为 O ( n 2 ln ⁡ n ) O(n^2\ln n) O(n2lnn)

code

#include
using namespace std;
const int N=2000;
int n,x,cnt[2005],w[2005];
long long ans,f[2][2005][2005];
long long mod=998244353;
bool cmp(int ax,int bx){
	return ax>bx;
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%d",&x);++cnt[x];
	}
	sort(cnt+1,cnt+n+1,cmp);
	for(int k=1;k<=n;k++){
		for(int i=1;i<=n;i++) w[k]+=min(k,cnt[i]);
	}
	for(int i=0;i<=n;i++) f[0][0][i]=1;
	for(int i=1,e=1;i<=n;i++,e^=1){
		for(int j=0;j<=w[i];j++){
			f[e][j][n/i+1]=0;
		}
		for(int k=n/i;k>=0;k--){
			for(int j=0;j<=w[i];j++){
				f[e][j][k]=f[e][j][k+1];
				if(j>=k) f[e][j][k]=(f[e][j][k]+f[e^1][j-k][k])%mod;
			}
		}
	}
	printf("%lld",f[n&1][n][0]);
	return 0;
}

你可能感兴趣的:(题解,c++,题解)