给定一颗二叉树。
每次删除一个叶子节点,并要求输出所有重链指向结点编号之和。
初始也要输出一次。
初始时,如果一个结点的两个儿子大小相同,则选择左儿子为重儿子。
一次删除操作后,如果一个结点的两个儿子大小相同,则不改变原先的重儿子选择。
如果删除一个叶子,重儿子指向可能会发生改变的一定是该结点到根节点的路径上的结点。
其中表现为——路径上的重边可能会变成轻边。
这个就不好处理……
如果是增加一个叶子,会有什么影响?
其中表现为——路径上的轻边可能会变成重边。
这个就好处理了!因为轻边数量<=log n。
只要沿着重链往上跳,每遇上一条轻边就看看它会不会变成重边即可。
可本题是删除……
别怕!我们先把结点插完,然后倒着处理。就把原来的删除操作改成添加操作了。
因为要动态维护子树size,并且要修改轻重边,所以我们用动态树链剖分(即用splay维护重链,注意这不是LCT,因为重链是严格意义上的重链,而且没有access操作)。
万一出现了两个子树大小相同,我们怎么知道哪个是重儿子呢?因为我们是倒着做的……
其实也很简单。
例如结点i的两个儿子j和k此时大小相同。
下一个插入进以j为根的子树中的结点的插入时间为t1。
下一个插入进以k为根的子树中的结点的插入时间为t2。
我们定义插入时间小的会被率先插入。
如果 t1<t2 ,显然j是重儿子。
如果 t1>t2 ,显然k是重儿子。
如果 t1=t2 ?谁是左儿子谁就是重儿子。
然后如何获得t1和t2呢?
因为已经知道最终树的形态,以及每个结点的插入时间。
那么,我们就是要查询子树最小值。
dfs序就可以搞了。注意每插入一个结点后要把其的插入时间改为无穷。
我的代码的实现,每个结点其实保留的是删除时间而非插入时间。
然后本题解决,然而我spaly大法被卡常最后90分。
有许多细节……
#include
#include
#include
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
using namespace std;
typedef long long ll;
const int maxn=200000+10;
int size[maxn],father[maxn],pp[maxn],zx[maxn],tree[maxn][2],add[maxn];
int fa[maxn],a[maxn],h[maxn],go[maxn],next[maxn],sta[maxn],del[maxn];
int big[maxn],dfn[maxn],num[maxn*5];
int left[maxn],right[maxn],lea[maxn];
bool bz[maxn],lch[maxn];
int i,j,k,l,t,n,m,tot,top,cnt,root;
ll ans,s[maxn];
void insert(int x,int y){
fa[y]=x;
go[++tot]=y;
next[tot]=h[x];
h[x]=tot;
}
int pd(int x){
return x==tree[father[x]][1];
}
void rotate(int x){
int y=father[x],z=pd(x);
father[x]=father[y];
if (father[y]) tree[father[y]][pd(y)]=x;
tree[y][z]=tree[x][1-z];
if (tree[x][1-z]) father[tree[x][1-z]]=y;
tree[x][1-z]=y;
father[y]=x;
if (pp[y]) pp[x]=pp[y],pp[y]=0;
}
void mark(int x,int y){
size[x]+=y;
add[x]+=y;
}
void clear(int x){
if (add[x]){
if (tree[x][0]) mark(tree[x][0],add[x]);
if (tree[x][1]) mark(tree[x][1],add[x]);
add[x]=0;
}
}
void remove(int x,int y){
int top=0;
while (x!=y){
sta[++top]=x;
x=father[x];
}
while (top){
clear(sta[top]);
top--;
}
}
void splay(int x,int y){
remove(x,y);
while (father[x]!=y){
if (father[father[x]]!=y)
if (pd(x)==pd(father[x])) rotate(father[x]);else rotate(x);
rotate(x);
}
}
void link(int x,int y){
pp[x]=y;
while (x){
splay(x,0);
size[x]++;
if (tree[x][0]) mark(tree[x][0],1);
x=pp[x];
}
}
void change(int p,int l,int r,int a,int b){
if (l==r){
num[p]=b;
return;
}
int mid=(l+r)/2;
if (a<=mid) change(p*2,l,mid,a,b);else change(p*2+1,mid+1,r,a,b);
num[p]=max(num[p*2],num[p*2+1]);
}
int query(int p,int l,int r,int a,int b){
if (l==a&&r==b) return num[p];
int mid=(l+r)/2;
if (b<=mid) return query(p*2,l,mid,a,b);
else if (a>mid) return query(p*2+1,mid+1,r,a,b);
else return max(query(p*2,l,mid,a,mid),query(p*2+1,mid+1,r,mid+1,b));
}
int nexttime(int x){
return query(1,1,n,dfn[x],dfn[x]+big[x]-1);
}
void solve(int x){
int y,z,t,j,k;
while (x!=root){
splay(x,0);
y=pp[x];
splay(y,0);
ans-=(ll)zx[y];
clear(y);
father[tree[y][1]]=0;
pp[tree[y][1]]=y;
tree[y][1]=0;
if (!zx[y]) zx[y]=x;
else{
if (2*size[x]>size[y]-1) zx[y]=x;
else if (2*size[x]==size[y]-1){
t=h[y];
while (t){
if (go[t]!=x){
z=go[t];
break;
}
t=next[t];
}
j=nexttime(x);k=nexttime(z);
if (j>k) zx[y]=x;
else if (j==k&&lch[x]) zx[y]=x;
}
}
ans+=(ll)zx[y];
splay(zx[y],0);
tree[y][1]=zx[y];
father[zx[y]]=y;
pp[zx[y]]=0;
x=y;
splay(x,0);
while (tree[x][0]) x=tree[x][0];
}
}
void dfs(int x){
int t=h[x];
while (t){
if (!bz[go[t]]){
link(go[t],x);
dfs(go[t]);
}
t=next[t];
}
splay(x,0);
t=h[x];
while (t){
if (!bz[go[t]]){
splay(go[t],0);
if (2*size[go[t]]>size[x]-1) zx[x]=go[t];
else if (2*size[go[t]]==size[x]-1&&lch[go[t]]) zx[x]=go[t];
}
t=next[t];
}
ans+=(ll)zx[x];
}
void dg(int x){
dfn[x]=++cnt;
int t=h[x];
while (t){
dg(go[t]);
big[x]+=big[go[t]];
t=next[t];
}
big[x]++;
}
void brute(){
fo(i,1,n) size[i]=big[i];
fo(i,1,n){
if (fa[i]){
if (lch[i]) left[fa[i]]=i;else right[fa[i]]=i;
lea[fa[i]]++;
}
}
fo(i,1,n){
if (big[left[i]]>=big[right[i]]) zx[i]=left[i];else zx[i]=right[i];
ans+=(ll)zx[i];
}
printf("%lld\n",ans);
scanf("%d",&m);
fo(i,1,m){
scanf("%d",&j);
lea[fa[j]]--;
while (j!=root){
k=fa[j];
big[j]--;
if (left[k]==j) l=right[k];else l=left[k];
ans-=(ll)zx[k];
if (big[j]if (!lea[k]) zx[k]=0;
ans+=(ll)zx[k];
j=k;
}
big[root]--;
printf("%lld\n",ans);
}
}
int read(){
int x=0;
char ch=getchar();
while (ch<'0'||ch>'9') ch=getchar();
while (ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x;
}
void now(){
printf("%d\n",clock());
}
int main(){
freopen("heavy.in","r",stdin);freopen("heavy.out","w",stdout);
n=read();
fo(i,1,n){
j=read();k=read();
if (j){
insert(i,j);
lch[j]=1;
}
if (k) insert(i,k);
}
//now();
fo(i,1,n)
if (!fa[i]){
root=i;
break;
}
dg(root);
//now();
if (n<=1000){
brute();
fclose(stdin);fclose(stdout);
return 0;
}
m=read();
fo(i,1,m){
a[i]=read();
bz[a[i]]=1;
del[a[i]]=i;
}
//now();
fo(i,1,n) change(1,1,n,dfn[i],del[i]);
//now();
dfs(root);
//now();
s[++top]=ans;
fd(i,m,1){
j=a[i];
change(1,1,n,dfn[j],0);
link(j,fa[j]);
solve(j);
s[++top]=ans;
//printf("%d\n",clock());
}
while (top){
printf("%lld\n",s[top]);
top--;
}
//now();
fclose(stdin);fclose(stdout);
return 0;
}