bzoj4825/洛谷P3721 单旋 splay

题目分析

有人问起我学会的第一个平衡树是什么。

我说是spaly。

在HNOI2017的考场上学会的。

俗话说的好,双旋的splay,单旋的spaly,不旋的saply,O(1)的asply,那么我们就来用splay做一做这道题。

首先我们手模一发单旋最小值操作。会发现,假如最小值节点是x,那么这个操作就是把x放到根,x的右子树给他原来的父亲当左子树,把原来的根节点给它做右子树。

思考思考就会发现,x的右子树的dfs序应该是连续的,准确的说,以x为根的子树应该是spaly的dfs序从左边开始的一段连续的区间,且这个区间里的所有节点的深度都要大于等于x的深度。

现在我们用一棵splay来维护,splay中节点的顺序就是按照权值排序,然后维护一下每个节点的dep值(深度)和其子树里的最小深度,然后每种操作的方法如下:

1.插入:我们寻找x的前驱和后继,发现要么前驱是后继的父亲,要么后继是前驱的父亲(因为前驱和后继的dfs序一定相邻,所以这两个节点一定相邻),所以新加入节点的深度就是max(dep(前驱),dep(后继))+1。除此之外,就简单地将新节点插入splay中即可。

2.单旋最小/大值:找到x的右/左子树代表的区间长度,首先将所有节点的dep +1,然后将x的右/左子树的节点 dep -1,然后再将x的dep单点赋值成1.

3.删除:删除x节点,将所有节点的深度都-1

这道题最难的地方,果然还是细节处理。我调试了两个半小时。由于每个人的写法不同,不予赘述我错了哪些细节,附赠一个丑丑的数据生成器,加油对拍吧。

代码

#include
using namespace std;
int read() {
	int q=0;char ch=' ';
	while(ch<'0'||ch>'9') ch=getchar();
	while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
	return q;
}
const int N=100005,inf=0x3f3f3f3f;
int m,rt,n;
int s[N][2],f[N],dep[N],v[N],laz[N],mn[N],sz[N];
void up(int x) {
	sz[x]=sz[s[x][0]]+sz[s[x][1]]+1;
	mn[x]=min(min(mn[s[x][0]],mn[s[x][1]]),dep[x]);
}
void pd(int x) {
	if(!laz[x]) return;
	int ls=s[x][0],rs=s[x][1],t=laz[x];
	if(ls) dep[ls]+=t,mn[ls]+=t,laz[ls]+=t;
	if(rs) dep[rs]+=t,mn[rs]+=t,laz[rs]+=t;
	laz[x]=0;
}
int is(int x) {return s[f[x]][1]==x;}
void spin(int x,int &mb) {
	int fa=f[x],g=f[fa],t=is(x);
	if(f[x]==mb) mb=x;
	else s[g][is(fa)]=x;
	f[x]=g,f[fa]=x,f[s[x][t^1]]=fa;
	s[fa][t]=s[x][t^1],s[x][t^1]=fa;
	up(fa),up(x);
}
void splay(int x,int &mb) {
	while(x!=mb) {
		if(f[x]!=mb) {
			if(is(x)^is(f[x])) spin(x,mb);
			else spin(f[x],mb);
		}
		spin(x,mb);
	}
}
int find(int x,int num) {//寻找dfs序第num的节点
	pd(x);
	if(sz[s[x][0]]+1==num) return x;
	if(sz[s[x][0]]>=num) return find(s[x][0],num);
	else return find(s[x][1],num-sz[s[x][0]]-1);
}
void add(int l,int r,int num) {//区间加
	int x=find(rt,l-1),y=find(rt,r+1);
	splay(x,rt),splay(y,s[x][1]);
	laz[s[y][0]]+=num,dep[s[y][0]]+=num,mn[s[y][0]]+=num;
	up(y),up(x);//注意pushup
}
int pre(int x,int num) {//前驱
	if(!x) return 0;
	pd(x);
	if(v[x]<num) {int kl=pre(s[x][1],num);return kl?kl:x;}
	else return pre(s[x][0],num);
}
int nxt(int x,int num) {//后继
	if(!x) return 0;
	pd(x);
	if(v[x]>num) {int kl=nxt(s[x][0],num);return kl?kl:x;}
	else return nxt(s[x][1],num);
}
void ins(int &x,int num,int d,int las) {//插入
	if(!x) {x=++n,f[x]=las,v[x]=num,dep[x]=mn[x]=d,sz[x]=1;return;}
	pd(x);
	if(num<v[x]) ins(s[x][0],num,d,x);
	else ins(s[x][1],num,d,x);
	up(x);
}
int getl(int x,int num) {//获得从左边开始的连续的dep[x]>=num的区间长度
	if(!x) return 0;
	pd(x);
	if(dep[x]>=num&&mn[s[x][0]]>=num) return sz[s[x][0]]+1+getl(s[x][1],num);
	else return getl(s[x][0],num);
}
int getr(int x,int num) {//获得从右边开始的连续的dep[x]>=num的区间长度
	if(!x) return 0;
	pd(x);
	if(dep[x]>=num&&mn[s[x][1]]>=num) return sz[s[x][1]]+1+getr(s[x][0],num);
	else return getr(s[x][1],num);
}
void chan(int x,int num) {//单点修改
	pd(x);
	if(v[x]==num) {mn[x]=dep[x]=1;return;}
	if(num<v[x]) chan(s[x][0],num);
	else chan(s[x][1],num);
	up(x);
}
void del(int x) {//删除
	splay(x,rt);
	if(s[x][0]*s[x][1]==0) rt=s[x][0]+s[x][1],f[rt]=0;
	else {
		int y=s[x][1];
		while(s[y][0]) pd(y),y=s[y][0];
		s[y][0]=s[x][0],f[s[x][0]]=y,rt=s[x][1],f[rt]=0;
		while(y) up(y),y=f[y];//记得pushup
	}
}
int main()
{
	int x,y;
	m=read();
	mn[0]=inf,ins(rt,-inf,inf,0),ins(rt,inf,inf,0);
	while(m--) {
		int bj=read();
		if(bj==1) {
			x=read();int a=pre(rt,x),b=nxt(rt,x);
			a=((a==1||a==2)?0:dep[a]),b=((b==1||b==2)?0:dep[b]);
			printf("%d\n",max(a,b)+1);
			ins(rt,x,max(a,b)+1,0),splay(n,rt);//这个splay用于维护平衡
		}
		if(bj==2||bj==4) {
			x=find(rt,2),printf("%d\n",dep[x]);
			y=min(getl(rt,dep[x]),sz[rt]-1);
			add(2,sz[rt]-1,1),add(2,y,-1);
			chan(rt,v[x]);
		}
		if(bj==3||bj==5) {
			x=find(rt,sz[rt]-1),printf("%d\n",dep[x]);
			y=min(getr(rt,dep[x]),sz[rt]-1);
			add(2,sz[rt]-1,1),add(sz[rt]-y+1,sz[rt]-1,-1);
			chan(rt,v[x]);
		}
		if(bj==4||bj==5) del(x),add(2,sz[rt]-1,-1);
	}
    return 0;
}

数据生成器

#include
using namespace std;
int a[100005],js,n;
void ins() {
	++js;
	int x=rand()%20+1;
	while(a[x]) x=rand()%20+1;
	a[x]=1;printf("1 %d\n",x);
}
int main()
{
	srand(time(NULL));
	n=rand()%10+1,printf("%d\n",n);
	while(n--) {
		if(!js) ins();
		else {
			int bj=rand()%5+1;
			if(bj==1) ins();
			else printf("%d\n",bj);
			if(bj==4||bj==5) --js;
		}
	}
	return 0;
}

你可能感兴趣的:(数据结构)