题目链接
题意:
给你一棵树,每个点有点权,问你最多能选出多少个点,使得所有选出的点中子节点的权值都比父节点小(严格小于)。点数2e5,权值1e9
题解:
首先的一个暴力是用一个树形dp, d p [ x ] [ i ] dp[x][i] dp[x][i]表示点 x x x为根的子树内,最大权值是 i i i时子树内最多选的点数。我们不难发现,随着这个 i i i的增大,最多选出的点数也是单调不降的。于是我们考虑从子节点转移过来, d p [ x ] [ i ] = m a x ( ∑ y ∈ s o n [ x ] d p [ y ] [ i ] , ∑ y ∈ s o n [ x ] d p [ x ] [ i − 1 ] + 1 ) dp[x][i]=max(\sum_{y\in son[x]}dp[y][i],\sum_{y\in son[x]}dp[x][i-1]+1) dp[x][i]=max(∑y∈son[x]dp[y][i],∑y∈son[x]dp[x][i−1]+1)。前一种是不选当前点,后一种是选了当前点,应该不难理解。我们发现其实第二维只与点的大小关系有关,与具体的点权值没有关系,于是我们离散化一下,就可以得到一个 O ( n 2 ) O(n^2) O(n2)的算法。
我们考虑能不能对这个dp进行优化。我们发现,如果对于当前 x x x,你有了每一个 i i i的 ∑ y ∈ s o n [ x ] d p [ y ] [ i ] \sum_{y\in son[x]}dp[y][i] ∑y∈son[x]dp[y][i],那么我们考虑当前点更新答案时可能更新哪些点。我们不难发现,由于 ∑ y ∈ s o n [ x ] d p [ y ] [ i ] \sum_{y\in son[x]}dp[y][i] ∑y∈son[x]dp[y][i]是单调递增的,所以能更新的区间也一定是一段连续的区间。而这个连续的区间的一个端点一定是当前 x x x点的值,显然是不可能更新更小的权值时的答案的。于是我们可以二分一下右端点。我们发现,我们找到这个区间后就只需要维护一个区间加的操作,这个可以线段树来做。而对于整个树形结构,我们可以用一个线段树合并来从儿子到父亲更新答案,线段树的权值 i i i维护的就是 d p [ x ] [ i ] dp[x][i] dp[x][i]的答案。这样就可以做到 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n),是可以通过本题的。
然而由于空间复杂度是 O ( n l o g n ) O(nlogn) O(nlogn)的,于是我因为数组开得过大而TLE掉了。。。
代码:
#include
#include
#include
#include
#include
#include
using namespace std;
int n,fa[200010],hed[200010],cnt,num,val[200010],ans;
int b[200010],root[200010],qwq;
struct node
{
int to,next;
}a[400010];
struct tree
{
int l,r,mx;
}tr[4000100];
inline int read()
{
int x=0;
char s=getchar();
while(s>'9'||s<'0')
s=getchar();
while(s>='0'&&s<='9')
{
x=x*10+s-'0';
s=getchar();
}
return x;
}
inline void add(int from,int to)
{
a[++cnt].to=to;
a[cnt].next=hed[from];
hed[from]=cnt;
}
inline void merge(int &l,int r)
{
if(!l||!r)
{
l=l+r;
return;
}
tr[l].mx=tr[l].mx+tr[r].mx;
merge(tr[l].l,tr[r].l);
merge(tr[l].r,tr[r].r);
}
inline int query(int rt,int l,int r,int x)
{
if(!rt||!x)
return 0;
if(l==r)
return tr[rt].mx;
int mid=(l+r)>>1,res=tr[rt].mx;
if(x<=mid)
res+=query(tr[rt].l,l,mid,x);
else
res+=query(tr[rt].r,mid+1,r,x);
return res;
}
inline void update(int &rt,int l,int r,int le,int ri)
{
if(!rt)
rt=++num;
if(le<=l&&r<=ri)
{
tr[rt].mx++;
return;
}
int mid=(l+r)>>1;
if(le<=mid)
update(tr[rt].l,l,mid,le,ri);
if(mid+1<=ri)
update(tr[rt].r,mid+1,r,le,ri);
}
inline void dfs(int x)
{
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
dfs(y);
merge(root[x],root[y]);
}
int ji=query(root[x],1,qwq,val[x]-1)+1;
if(ji<=query(root[x],1,qwq,val[x]))
return;
int l=val[x],r=qwq,mid,res=val[x];
while(l<=r)
{
mid=(l+r)>>1;
if(query(root[x],1,qwq,mid)<ji)
{
res=mid;
l=mid+1;
}
else
r=mid-1;
}
update(root[x],1,qwq,val[x],res);
}
int main()
{
n=read();
for(int i=1;i<=n;++i)
{
val[i]=read();
b[i]=val[i];
fa[i]=read();
add(fa[i],i);
}
sort(b+1,b+n+1);
qwq=unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;++i)
val[i]=lower_bound(b+1,b+qwq+1,val[i])-b;
dfs(1);
ans=query(root[1],1,qwq,qwq);
printf("%d\n",ans);
return 0;
}