CF1399F Yet Another Segments Subset 区间DP

先来吐槽两句:这篇文章本应该是发在博客园的,但是由于博客园的markdown没用明白,于是就只能继续用CSDN了。这段时间算是以赛代练吧,基本每场能打的CF都去打了,新建的小号分和以前的号差不多了,但是自我感觉水平还是没有恢复到之前。

题目链接:
https://codeforces.com/contest/1399/problem/F

题意:
t组询问,每组给你n条线段,保证线段两两不同,告诉你每条线段的左右端点坐标,要求你选出一组线段,使得这组线段两两之间的关系要么是相互包含,要么是没有交点,求组内最多能选择的线段数量。 ∑ n < = 3000 \sum n<=3000 n<=3000 左右端点坐标 < = 2 e 5 <=2e5 <=2e5

题解:
一道有趣的区间dp题。

不难根据数据范围猜到要用一种 O ( n 2 ) O(n^2) O(n2)的算法。

因为坐标有点大,显然要先离散化。

当时打比赛的时候我也想过按照某种顺序排序然后dp,但是发现自己的二维状态总是会在转移时遇到问题,大概在设计状态和子问题的转移方面还是存在问题。

这个题的想法是用区间dp的思路,设 d p [ l ] [ r ] dp[l][r] dp[l][r]表示离散化后的区间 [ l , r ] [l,r] [l,r]最多能选出多少条符合题意的线段。一般的区间dp是要枚举中间断点来转移,这里这样就会复杂度炸掉。这里的转移方法比较独特,是先用一些vector,其中 v e c t o r [ l ] vector[l] vector[l]存储所有左端点在 l l l的线段(右端点不需要有序)。这样我们用一种记忆化搜索的写法,每次枚举到区间 [ l , r ] [l,r] [l,r]时,我们就考虑两种转移:

第一种是当前如果存在一个线段就是 [ l , r ] [l,r] [l,r],那么我们可以用它来覆盖 [ l + 1 , r ] [l+1,r] [l+1,r]的答案,故转移为 d p [ l , r ] = m a x ( d p [ l ] [ r ] , d p [ l + 1 ] [ r ] + 1 ) dp[l,r]=max(dp[l][r],dp[l+1][r]+1) dp[l,r]=max(dp[l][r],dp[l+1][r]+1)。转移时注意 l + 1 l+1 l+1是否小于 r r r

第二种是枚举所有读入线段中左端点在 l l l的线段,如果它们的右端点大于 r r r,那么显然不能用来更新 d p [ l ] [ r ] dp[l][r] dp[l][r]的答案。如果它们的右端点等于 r r r,那么在第一种情况里我们已经更新过了,这里不要重复更新了。如果它们的右端点小于 r r r,那么这条线段可能可以更新答案。我们设 o p t opt opt为是否存在左端点为 l l l,右端点为 r r r的区间,存在则 o p t = 1 opt=1 opt=1,不存在则 o p t = 0 opt=0 opt=0,设当前枚举到的这条线段的右端点为 r ′ r' r。那么更新的转移方程是 d p [ l ] [ r ] = m a x ( d p [ l ] [ r ] , o p t + d p [ l ] [ r ] ′ + d p [ r ′ + 1 ] [ r ] ) dp[l][r]=max(dp[l][r],opt+dp[l][r]'+dp[r'+1][r]) dp[l][r]=max(dp[l][r],opt+dp[l][r]+dp[r+1][r])

这样我们的转移就完成了。

下面分析一下复杂度。空间复杂度是 O ( n 2 ) O(n^2) O(n2)的自然不用多说,听说这题出题人可能开小空间限制了,所以好像有人被卡空间了。然后时间复杂度的话,总的转移复杂度是 O ( n ) O(n) O(n)级别的,因为第一张转移的总复杂度是 O ( n ) O(n) O(n),第二种转移的总复杂度也是 O ( n ) O(n) O(n),复杂度瓶颈大概是所有状态都可能被枚举到,于是是状态数 O ( n 2 ) O(n^2) O(n2)

另外一个吐槽是,我这题一开始不会count操作,用map写,结果T了。

说明一下代码里的count函数,这个函数是数两个迭代器之间的元素里要查询的元素出现了几次。这里对于每个 l l l,只会进行一次count,所以总复杂度是 O ( n ) O(n) O(n)的,并没有因此爆复杂度。

先是能AC的用vector的代码:

#include 
using namespace std;

int t,n,m,l[3010],r[3010],cnt,a[6010];
int dp[6010][6010];
vector<int> v[6010];
inline int cal(int l,int r)
{
     
	if(dp[l][r]!=-1)
	return dp[l][r];
	dp[l][r]=0;
	int pd=count(v[l].begin(),v[l].end(),r);
	dp[l][r]=max(dp[l][r],pd+(l==r?0:cal(l+1,r)));
	int mx=v[l].size();
	for(int i=0;i<mx;++i)
	{
     
		int ri=v[l][i];
		if(ri>=r)
		continue;
		dp[l][r]=max(dp[l][r],pd+cal(l,ri)+cal(ri+1,r));
	}
	return dp[l][r];
}
int main()
{
     
	scanf("%d",&t);
	while(t--)
	{
     
		scanf("%d",&n);
		cnt=0;
		for(int i=1;i<=n;++i)
		{
     
			scanf("%d%d",&l[i],&r[i]);
			a[++cnt]=l[i];
			a[++cnt]=r[i];
		}
		sort(a+1,a+cnt+1);
		m=unique(a+1,a+cnt+1)-a-1;
		for(int i=1;i<=n;++i)
		{
     
			l[i]=lower_bound(a+1,a+m+1,l[i])-a;
			r[i]=lower_bound(a+1,a+m+1,r[i])-a;
			v[l[i]].push_back(r[i]);
		}
		for(int i=1;i<=m;++i)
		{
     
			for(int j=1;j<=m;++j)
			dp[i][j]=-1;
		}
		printf("%d\n",cal(1,m));
		for(int i=1;i<=n;++i)
		v[l[i]].clear();
		for(int i=1;i<=cnt;++i)
		a[i]=0;
	}
	return 0;
}

然后是T了,但是我觉得正确的map代码:

#include 
using namespace std;

int t,n,m,l[3010],r[3010],cnt,a[6010];
int dp[6010][6010];
map<int,int> mp[6010];
inline int cal(int l,int r)
{
     
	if(dp[l][r]!=-1)
	return dp[l][r];
	dp[l][r]=0;
	int pd=mp[l][r];
	dp[l][r]=max(dp[l][r],pd+(l==r?0:cal(l+1,r)));
	map<int,int>::iterator it;
	for(it=mp[l].begin();it!=mp[l].end();++it)
	{
     
		int ri=it->first;
		if(ri>=r)
		continue;
		dp[l][r]=max(dp[l][r],pd+cal(l,ri)+cal(ri+1,r));
	}
	return dp[l][r];
}
int main()
{
     
	scanf("%d",&t);
	while(t--)
	{
     
		scanf("%d",&n);
		cnt=0;
		for(int i=1;i<=n;++i)
		{
     
			scanf("%d%d",&l[i],&r[i]);
			a[++cnt]=l[i];
			a[++cnt]=r[i];
		}
		sort(a+1,a+cnt+1);
		m=unique(a+1,a+cnt+1)-a-1;
		for(int i=1;i<=n;++i)
		{
     
			l[i]=lower_bound(a+1,a+m+1,l[i])-a;
			r[i]=lower_bound(a+1,a+m+1,r[i])-a;
			mp[l[i]][r[i]]=1;
		}
		for(int i=1;i<=m;++i)
		{
     
			for(int j=1;j<=m;++j)
			dp[i][j]=-1;
		}
		printf("%d\n",cal(1,m));
		for(int i=1;i<=n;++i)
		mp[l[i]][r[i]]=0;
		for(int i=1;i<=cnt;++i)
		a[i]=0;
	}
	return 0;
}

你可能感兴趣的:(dp)