BZOJ4919 大根堆 线段树合并 二分 离散化

题目链接

题意:
给你一棵树,每个点有点权,问你最多能选出多少个点,使得所有选出的点中子节点的权值都比父节点小(严格小于)。点数2e5,权值1e9

题解:
首先的一个暴力是用一个树形dp, d p [ x ] [ i ] dp[x][i] dp[x][i]表示点 x x x为根的子树内,最大权值是 i i i时子树内最多选的点数。我们不难发现,随着这个 i i i的增大,最多选出的点数也是单调不降的。于是我们考虑从子节点转移过来, d p [ x ] [ i ] = m a x ( ∑ y ∈ s o n [ x ] d p [ y ] [ i ] , ∑ y ∈ s o n [ x ] d p [ x ] [ i − 1 ] + 1 ) dp[x][i]=max(\sum_{y\in son[x]}dp[y][i],\sum_{y\in son[x]}dp[x][i-1]+1) dp[x][i]=max(yson[x]dp[y][i],yson[x]dp[x][i1]+1)。前一种是不选当前点,后一种是选了当前点,应该不难理解。我们发现其实第二维只与点的大小关系有关,与具体的点权值没有关系,于是我们离散化一下,就可以得到一个 O ( n 2 ) O(n^2) O(n2)的算法。

我们考虑能不能对这个dp进行优化。我们发现,如果对于当前 x x x,你有了每一个 i i i ∑ y ∈ s o n [ x ] d p [ y ] [ i ] \sum_{y\in son[x]}dp[y][i] yson[x]dp[y][i],那么我们考虑当前点更新答案时可能更新哪些点。我们不难发现,由于 ∑ y ∈ s o n [ x ] d p [ y ] [ i ] \sum_{y\in son[x]}dp[y][i] yson[x]dp[y][i]是单调递增的,所以能更新的区间也一定是一段连续的区间。而这个连续的区间的一个端点一定是当前 x x x点的值,显然是不可能更新更小的权值时的答案的。于是我们可以二分一下右端点。我们发现,我们找到这个区间后就只需要维护一个区间加的操作,这个可以线段树来做。而对于整个树形结构,我们可以用一个线段树合并来从儿子到父亲更新答案,线段树的权值 i i i维护的就是 d p [ x ] [ i ] dp[x][i] dp[x][i]的答案。这样就可以做到 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n),是可以通过本题的。

然而由于空间复杂度是 O ( n l o g n ) O(nlogn) O(nlogn)的,于是我因为数组开得过大而TLE掉了。。。

代码:

#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;

int n,fa[200010],hed[200010],cnt,num,val[200010],ans;
int b[200010],root[200010],qwq;
struct node
{
	int to,next;
}a[400010];
struct tree
{
	int l,r,mx;
}tr[4000100];
inline int read()
{
	int x=0;
	char s=getchar();
	while(s>'9'||s<'0')
	s=getchar();
	while(s>='0'&&s<='9')
	{
		x=x*10+s-'0';
		s=getchar();
	}
	return x; 
} 
inline void add(int from,int to)
{
	a[++cnt].to=to;
	a[cnt].next=hed[from];
	hed[from]=cnt;
}
inline void merge(int &l,int r)
{
	if(!l||!r)
	{
		l=l+r;
		return;
	}
	tr[l].mx=tr[l].mx+tr[r].mx;
	merge(tr[l].l,tr[r].l);
	merge(tr[l].r,tr[r].r);
}
inline int query(int rt,int l,int r,int x)
{
	if(!rt||!x)
	return 0;
	if(l==r)
	return tr[rt].mx;
	int mid=(l+r)>>1,res=tr[rt].mx;
	if(x<=mid)
	res+=query(tr[rt].l,l,mid,x);
	else
	res+=query(tr[rt].r,mid+1,r,x);
	return res;
}
inline void update(int &rt,int l,int r,int le,int ri)
{
	if(!rt)
	rt=++num;
	if(le<=l&&r<=ri)
	{
		tr[rt].mx++;
		return;
	}
	int mid=(l+r)>>1;
	if(le<=mid)
	update(tr[rt].l,l,mid,le,ri);
	if(mid+1<=ri)
	update(tr[rt].r,mid+1,r,le,ri);
}
inline void dfs(int x)
{
	for(int i=hed[x];i;i=a[i].next)
	{
		int y=a[i].to;
		dfs(y);
		merge(root[x],root[y]);
	}
	int ji=query(root[x],1,qwq,val[x]-1)+1;
	if(ji<=query(root[x],1,qwq,val[x]))
	return;
	int l=val[x],r=qwq,mid,res=val[x];
	while(l<=r)
	{
		mid=(l+r)>>1;
		if(query(root[x],1,qwq,mid)<ji)
		{
			res=mid;
			l=mid+1;
		}
		else
		r=mid-1;
	}
	update(root[x],1,qwq,val[x],res);
}
int main()
{
	n=read();
	for(int i=1;i<=n;++i)
	{
		val[i]=read();
		b[i]=val[i];
		fa[i]=read();
		add(fa[i],i);
	}
	sort(b+1,b+n+1);
	qwq=unique(b+1,b+n+1)-b-1;
	for(int i=1;i<=n;++i)
	val[i]=lower_bound(b+1,b+qwq+1,val[i])-b;	
	dfs(1);
	ans=query(root[1],1,qwq,qwq);
	printf("%d\n",ans);
	return 0;
}

你可能感兴趣的:(数据结构,线段树,二分,离散化,线段树合并)