给定一棵 n 个节点的二叉树,对其进行轻重路径剖分, size 相同则优先选择左儿子。
有 q 个操作,每次会删除一个点 x (有持续影响),要求动态维护轻重路径剖分(如果 size 相同优先保留原本剖分方案)。
你需要输出 q+1 个数,第一个表示删点前的重边指向节点编号之和,后面 q 个都是删点后的重边指向节点编号和。
1≤n,q≤200000
首先显然答案为所有点编号之和减去重链链顶节点编号之和。
每次操作都会使 size 发生绝对值为 1 的变动,因此如果某个点重边指向变化,那么原来其左右儿子 size 一定相等。因此每次操作最多改变 log2n 条重边指向。
考虑使用树链剖分维护链剖,使用数据结构维护区间内每个点的重儿子与轻儿子 size 之差 v 的最小值,以及是否存在某个点为重链顶 p 。
每次删除点 x ,我们就找到 x 到根的路径。对于路径上的点,如果该点是链顶(你可以使用数据结构维护的信息找到链顶)的父亲,那么其 v 加 1 ,其余点 v 都加 −1 ,这个使用区间加标记维护即可。
然后我们找到路径上所有 v<0 的点,这些点的重边指向都要发生变化。我们将它们 v 取反,将重儿子的 p 设置为 true ,轻儿子的 p 设置为 false (更改链顶信息),同时更新一下答案即可。
注意特判删除点,如果其兄弟已被删除或者本来就不存在,那么答案还要将其减去。
还有就是怎么找到符合条件( v<0 或 p=true )的点呢?在这里我们是先找到 DFS 序最大的,然后再往其左边剩余区间查找。我们先在数据结构上找到对应的最右的区间,然后如果该区间的 v 或 p 满足要求,我们就递归下去,否则到左边的下一个区间找。
时间复杂度:单次更新所有 v 和更新剖分复杂度都是 O((log2n)2) 的,因此总的复杂度就是 O(q(log2n)2) ,时限3000ms,强行碾过去。
听Werkeytom_FTD说使用splay维护重链信息的伪 LCT 可以将复杂度降至 O(qlog2n) ,于是他考场上辛辛苦苦敲出来4000+b的代码被卡常成了90分(悲伤的故事)。
因此,最后祝你,常数要写好,不要爆了,再见。
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
using namespace std;
typedef long long LL;
int read()
{
int x=0,f=1;
char ch=getchar();
while (!isdigit(ch))
{
if (ch=='-')
f=-1;
ch=getchar();
}
while (isdigit(ch))
{
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
const int N=200500;
LL ans;
int son[N][2],size[N],fa[N],DFN[N],tra[N],prt[N],brt[N],det[N],hea[N];
int n,q,idx,rt;
struct segment_tree
{
int tag[N<<2],v[N<<2];
bool P[N<<2];
void add(int x,int edit)
{
tag[x]+=edit;
v[x]+=edit;
}
void clear(int x,int l,int r)
{
if (l==r)
return;
add(x<<1,tag[x]),add(x<<1|1,tag[x]);
tag[x]=0;
}
void update(int x)
{
v[x]=min(v[x<<1],v[x<<1|1]);
P[x]=P[x<<1|1]||P[x<<1];
}
void change(int x,int st,int en,int l,int r,int edit)
{
if (st>en)
return;
clear(x,l,r);
if (st==l&&en==r)
{
add(x,edit);
return;
}
int mid=l+r>>1;
if (en<=mid)
change(x<<1,st,en,l,mid,edit);
else
if (mid+1<=st)
change(x<<1|1,st,en,mid+1,r,edit);
else
change(x<<1,st,mid,l,mid,edit),change(x<<1|1,mid+1,en,mid+1,r,edit);
update(x);
}
void R(int x,int y,int l,int r)
{
clear(x,l,r);
if (l==r)
{
P[x]^=1;
return;
}
int mid=l+r>>1;
if (y<=mid)
R(x<<1,y,l,mid);
else
R(x<<1|1,y,mid+1,r);
update(x);
}
void Rn(int x,int y,int l,int r)
{
clear(x,l,r);
if (l==r)
{
v[x]*=-1;
return;
}
int mid=l+r>>1;
if (y<=mid)
Rn(x<<1,y,l,mid);
else
Rn(x<<1|1,y,mid+1,r);
update(x);
}
int getC(int x,int l,int r)
{
if (v[x]>=0)
return 0;
clear(x,l,r);
if (l==r)
return l;
int mid=l+r>>1;
int ret=getC(x<<1|1,mid+1,r);
if (ret)
return ret;
return getC(x<<1,l,mid);
}
int getc(int x,int st,int en,int l,int r)
{
if (st>en)
return 0;
clear(x,l,r);
if (v[x]>=0)
return 0;
if (st==l&&en==r)
return getC(x,l,r);
int mid=l+r>>1;
if (en<=mid)
return getc(x<<1,st,en,l,mid);
else
if (mid+1<=st)
return getc(x<<1|1,st,en,mid+1,r);
else
{
int ret=getc(x<<1|1,mid+1,en,mid+1,r);
if (ret)
return ret;
return getc(x<<1,st,mid,l,mid);
}
}
int getP(int x,int l,int r)
{
if (!P[x])
return 0;
clear(x,l,r);
if (l==r)
return l;
int mid=l+r>>1;
int ret=getP(x<<1|1,mid+1,r);
if (ret)
return ret;
return getP(x<<1,l,mid);
}
int getp(int x,int st,int en,int l,int r)
{
if (st>en)
return 0;
clear(x,l,r);
if (st==l&&en==r)
return getP(x,l,r);
int mid=l+r>>1;
if (en<=mid)
return getp(x<<1,st,en,l,mid);
else
if (mid+1<=st)
return getp(x<<1|1,st,en,mid+1,r);
else
{
int ret=getp(x<<1|1,mid+1,en,mid+1,r);
if (ret)
return ret;
return getp(x<<1,st,mid,l,mid);
}
}
void build(int x,int l,int r)
{
if (l==r)
{
int u=tra[l];
P[x]=false;
if (prt[u]==u)
{
P[x]=true;
ans-=u;
}
v[x]=abs(size[son[u][0]]-size[son[u][1]]);
return;
}
int mid=l+r>>1;
build(x<<1,l,mid),build(x<<1|1,mid+1,r);
update(x);
}
}t;
void dfs(int x)
{
if (!x)
return;
dfs(son[x][0]),dfs(son[x][1]);
size[x]=size[son[x][0]]+size[son[x][1]]+1;
if (size[son[x][0]]>=size[son[x][1]])
hea[x]=son[x][0];
else
hea[x]=son[x][1];
}
void build(int x,int PRT)
{
if (!x)
return;
prt[x]=PRT;
DFN[x]=++idx;
tra[idx]=x;
build(hea[x],PRT),build(brt[hea[x]],brt[hea[x]]);
}
int top[N];
void resize(int x)
{
int u=x;
top[0]=0;
while (x)
{
int head=t.getp(1,DFN[prt[x]],DFN[x],1,idx),las=DFN[x];
while (head)
{
int f=fa[tra[head]];
top[++top[0]]=f;
las=head-1;
head=t.getp(1,DFN[prt[x]],head-1,1,idx);
}
x=fa[prt[x]];
}
x=fa[u];
int cur=1;
while (x)
{
int las=DFN[x];
while (cur<=top[0]&&prt[top[cur]]==prt[x])
{
t.change(1,DFN[top[cur]]+1,las,1,idx,-1);
t.change(1,DFN[top[cur]],DFN[top[cur]],1,idx,1);
las=DFN[top[cur]]-1;
cur++;
}
t.change(1,DFN[prt[x]],las,1,idx,-1);
x=fa[prt[x]];
}
}
void redec(int x)
{
x=fa[x];
while (x)
{
int H=t.getc(1,DFN[prt[x]],DFN[x],1,idx);
while (H)
{
t.Rn(1,H,1,idx);
t.R(1,DFN[hea[tra[H]]],1,idx);
(ans-=hea[tra[H]])+=brt[hea[tra[H]]];
hea[tra[H]]=brt[hea[tra[H]]];
t.R(1,DFN[hea[tra[H]]],1,idx);
H=t.getc(1,DFN[prt[x]],H-1,1,idx);
}
x=fa[prt[x]];
}
}
int main()
{
freopen("heavy.in","r",stdin);
freopen("heavy.out","w",stdout);
ans=0;
n=read();
for (int i=1;i<=n;i++)
{
ans+=i;
son[i][0]=read(),son[i][1]=read();
if (son[i][0]) brt[son[i][0]]=son[i][1],fa[son[i][0]]=i;
if (son[i][1]) brt[son[i][1]]=son[i][0],fa[son[i][1]]=i;
}
for (int i=1;i<=n;i++)
if (!fa[i])
{
rt=i;
break;
}
dfs(rt);
build(rt,rt);
t.build(1,1,idx);
printf("%lld\n",ans);
q=read();
while (q--)
{
int x=read();
resize(x);
redec(x);
if (det[brt[x]]||!brt[x])
ans-=x;
det[x]=true;
printf("%lld\n",ans);
}
fclose(stdin);
fclose(stdout);
return 0;
}