先来吐槽两句:这篇文章本应该是发在博客园的,但是由于博客园的markdown没用明白,于是就只能继续用CSDN了。这段时间算是以赛代练吧,基本每场能打的CF都去打了,新建的小号分和以前的号差不多了,但是自我感觉水平还是没有恢复到之前。
题目链接:
https://codeforces.com/contest/1399/problem/F
题意:
t组询问,每组给你n条线段,保证线段两两不同,告诉你每条线段的左右端点坐标,要求你选出一组线段,使得这组线段两两之间的关系要么是相互包含,要么是没有交点,求组内最多能选择的线段数量。 ∑ n < = 3000 \sum n<=3000 ∑n<=3000 左右端点坐标 < = 2 e 5 <=2e5 <=2e5
题解:
一道有趣的区间dp题。
不难根据数据范围猜到要用一种 O ( n 2 ) O(n^2) O(n2)的算法。
因为坐标有点大,显然要先离散化。
当时打比赛的时候我也想过按照某种顺序排序然后dp,但是发现自己的二维状态总是会在转移时遇到问题,大概在设计状态和子问题的转移方面还是存在问题。
这个题的想法是用区间dp的思路,设 d p [ l ] [ r ] dp[l][r] dp[l][r]表示离散化后的区间 [ l , r ] [l,r] [l,r]最多能选出多少条符合题意的线段。一般的区间dp是要枚举中间断点来转移,这里这样就会复杂度炸掉。这里的转移方法比较独特,是先用一些vector,其中 v e c t o r [ l ] vector[l] vector[l]存储所有左端点在 l l l的线段(右端点不需要有序)。这样我们用一种记忆化搜索的写法,每次枚举到区间 [ l , r ] [l,r] [l,r]时,我们就考虑两种转移:
第一种是当前如果存在一个线段就是 [ l , r ] [l,r] [l,r],那么我们可以用它来覆盖 [ l + 1 , r ] [l+1,r] [l+1,r]的答案,故转移为 d p [ l , r ] = m a x ( d p [ l ] [ r ] , d p [ l + 1 ] [ r ] + 1 ) dp[l,r]=max(dp[l][r],dp[l+1][r]+1) dp[l,r]=max(dp[l][r],dp[l+1][r]+1)。转移时注意 l + 1 l+1 l+1是否小于 r r r。
第二种是枚举所有读入线段中左端点在 l l l的线段,如果它们的右端点大于 r r r,那么显然不能用来更新 d p [ l ] [ r ] dp[l][r] dp[l][r]的答案。如果它们的右端点等于 r r r,那么在第一种情况里我们已经更新过了,这里不要重复更新了。如果它们的右端点小于 r r r,那么这条线段可能可以更新答案。我们设 o p t opt opt为是否存在左端点为 l l l,右端点为 r r r的区间,存在则 o p t = 1 opt=1 opt=1,不存在则 o p t = 0 opt=0 opt=0,设当前枚举到的这条线段的右端点为 r ′ r' r′。那么更新的转移方程是 d p [ l ] [ r ] = m a x ( d p [ l ] [ r ] , o p t + d p [ l ] [ r ] ′ + d p [ r ′ + 1 ] [ r ] ) dp[l][r]=max(dp[l][r],opt+dp[l][r]'+dp[r'+1][r]) dp[l][r]=max(dp[l][r],opt+dp[l][r]′+dp[r′+1][r])
这样我们的转移就完成了。
下面分析一下复杂度。空间复杂度是 O ( n 2 ) O(n^2) O(n2)的自然不用多说,听说这题出题人可能开小空间限制了,所以好像有人被卡空间了。然后时间复杂度的话,总的转移复杂度是 O ( n ) O(n) O(n)级别的,因为第一张转移的总复杂度是 O ( n ) O(n) O(n),第二种转移的总复杂度也是 O ( n ) O(n) O(n),复杂度瓶颈大概是所有状态都可能被枚举到,于是是状态数 O ( n 2 ) O(n^2) O(n2)。
另外一个吐槽是,我这题一开始不会count操作,用map写,结果T了。
说明一下代码里的count函数,这个函数是数两个迭代器之间的元素里要查询的元素出现了几次。这里对于每个 l l l,只会进行一次count,所以总复杂度是 O ( n ) O(n) O(n)的,并没有因此爆复杂度。
先是能AC的用vector的代码:
#include
using namespace std;
int t,n,m,l[3010],r[3010],cnt,a[6010];
int dp[6010][6010];
vector<int> v[6010];
inline int cal(int l,int r)
{
if(dp[l][r]!=-1)
return dp[l][r];
dp[l][r]=0;
int pd=count(v[l].begin(),v[l].end(),r);
dp[l][r]=max(dp[l][r],pd+(l==r?0:cal(l+1,r)));
int mx=v[l].size();
for(int i=0;i<mx;++i)
{
int ri=v[l][i];
if(ri>=r)
continue;
dp[l][r]=max(dp[l][r],pd+cal(l,ri)+cal(ri+1,r));
}
return dp[l][r];
}
int main()
{
scanf("%d",&t);
while(t--)
{
scanf("%d",&n);
cnt=0;
for(int i=1;i<=n;++i)
{
scanf("%d%d",&l[i],&r[i]);
a[++cnt]=l[i];
a[++cnt]=r[i];
}
sort(a+1,a+cnt+1);
m=unique(a+1,a+cnt+1)-a-1;
for(int i=1;i<=n;++i)
{
l[i]=lower_bound(a+1,a+m+1,l[i])-a;
r[i]=lower_bound(a+1,a+m+1,r[i])-a;
v[l[i]].push_back(r[i]);
}
for(int i=1;i<=m;++i)
{
for(int j=1;j<=m;++j)
dp[i][j]=-1;
}
printf("%d\n",cal(1,m));
for(int i=1;i<=n;++i)
v[l[i]].clear();
for(int i=1;i<=cnt;++i)
a[i]=0;
}
return 0;
}
然后是T了,但是我觉得正确的map代码:
#include
using namespace std;
int t,n,m,l[3010],r[3010],cnt,a[6010];
int dp[6010][6010];
map<int,int> mp[6010];
inline int cal(int l,int r)
{
if(dp[l][r]!=-1)
return dp[l][r];
dp[l][r]=0;
int pd=mp[l][r];
dp[l][r]=max(dp[l][r],pd+(l==r?0:cal(l+1,r)));
map<int,int>::iterator it;
for(it=mp[l].begin();it!=mp[l].end();++it)
{
int ri=it->first;
if(ri>=r)
continue;
dp[l][r]=max(dp[l][r],pd+cal(l,ri)+cal(ri+1,r));
}
return dp[l][r];
}
int main()
{
scanf("%d",&t);
while(t--)
{
scanf("%d",&n);
cnt=0;
for(int i=1;i<=n;++i)
{
scanf("%d%d",&l[i],&r[i]);
a[++cnt]=l[i];
a[++cnt]=r[i];
}
sort(a+1,a+cnt+1);
m=unique(a+1,a+cnt+1)-a-1;
for(int i=1;i<=n;++i)
{
l[i]=lower_bound(a+1,a+m+1,l[i])-a;
r[i]=lower_bound(a+1,a+m+1,r[i])-a;
mp[l[i]][r[i]]=1;
}
for(int i=1;i<=m;++i)
{
for(int j=1;j<=m;++j)
dp[i][j]=-1;
}
printf("%d\n",cal(1,m));
for(int i=1;i<=n;++i)
mp[l[i]][r[i]]=0;
for(int i=1;i<=cnt;++i)
a[i]=0;
}
return 0;
}