最近看到有一种不用旋转的treap,好像还可以持久化,于是就学了一下。
treap就是tree+heap。它的每个节点的权值data满足排序二叉树的性质,随机权值key满足堆的性质。由于key是随机的所以它大致是平衡的。
不基于旋转的treap有两个基本操作:
merge(a,b):返回一个treap,包含a,b两个treap中的所有节点,但要保证b中所有节点权值都大于等于a。
split(a,n)返回两个treap l,r。其中l中包含treap a中的前n个节点,r中包含treap a中的剩余节点。
这两个操作的实现都很简单(这里我们维护小根堆的性质):
merge(a,b):若a的key< b的key则将a的右儿子变为merge(a的右儿子,b)。
否则将b的左儿子变为merge(a,b的左儿子)。
split(a,n):若a左子树的size(记为m)=n则返回a的左子树,a和a的右子树。若m=n-1则返回a的左子树和a,a的右子树。否则若m>n则设{l,r}为split(a的左子树,n)将a的左子树设为r,返回l,a。若m< n-1则设{l,r}为split(a的右子树,n-m-1)将a的右儿子设为l,返回a,r。
有了这两操作我们就可以实现插入和删除了。
插入x:找到x所在位置,将其split开,在合并l,x与x,r。
删除x:找到x的位置,将x与其前后位置都split开,在合并另外两部份。
其余操作和不同平衡树一样。
由于不基于旋转,我们可以将其可持久化,或者套上其他数据结构。
bzoj3224`
一道裸的平衡树题。就拿它做模板题吧。
代码:
#include<cstdio>
#include<algorithm>
using namespace std;
#define maxn 100010
#define mp make_pair
typedef pair<int,int> par;
struct node{int ls,rs,data,key,size;}t[maxn];
int op,n,x,root,num;
void updata(int x){
t[x].size=t[t[x].ls].size+t[t[x].rs].size+1;
}
par split(int a,int n){
if(n==0)return mp(0,a);
int ls=t[a].ls,rs=t[a].rs;
if(n==t[ls].size) return t[a].ls=0,updata(a),mp(ls,a);
if(n==t[ls].size+1) return t[a].rs=0,updata(a),mp(a,rs);
if(n<t[ls].size){
par tmp=split(ls,n);
return t[a].ls=tmp.second,updata(a),mp(tmp.first,a);
}
par tmp=split(rs,n-t[ls].size-1);
return t[a].rs=tmp.first,updata(a),mp(a,tmp.second);
}
int merge(int a,int b){
if(a==0||b==0)return a+b;
if(t[a].key<t[b].key) return t[a].rs=merge(t[a].rs,b),updata(a),a;
else return t[b].ls=merge(a,t[b].ls),updata(b),b;
}
int rank(int x,int k){
int ans=0,tmp=(int)1e9;
while(k){
if(x==t[k].data)tmp=min(tmp,ans+t[t[k].ls].size+1);
if(x>t[k].data)ans+=t[t[k].ls].size+1,k=t[k].rs;
else k=t[k].ls;
}
return tmp==(int)1e9?ans:tmp;
}
int find(int x,int k){
while(true){
if(t[t[k].ls].size==x-1)return t[k].data;
if(t[t[k].ls].size>x-1)k=t[k].ls;
else x=x-t[t[k].ls].size-1,k=t[k].rs;
}
}
int pre(int x,int k){
int ans=-(int)1e9;
while(k){
if(t[k].data<x)ans=max(ans,t[k].data),k=t[k].rs;
else k=t[k].ls;
}
return ans;
}
int neg(int x,int k){
int ans=(int)1e9;
while(k){
if(t[k].data>x)ans=min(ans,t[k].data),k=t[k].ls;
else k=t[k].rs;
}
return ans;
}
void insert(int x){
int k=rank(x,root);
par tmp=split(root,k);
t[++num].data=x;
t[num].key=rand();
t[num].size=1;
root=merge(tmp.first,num);
root=merge(root,tmp.second);
}
void del(int x){
int k=rank(x,root);
par t1=split(root,k);
par t2=split(t1.first,k-1);
root=merge(t2.first,t1.second);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d%d",&op,&x);
if(op==1)insert(x);
else if(op==2)del(x);
else if(op==3)printf("%d\n",rank(x,root));
else if(op==4)printf("%d\n",find(x,root));
else if(op==5)printf("%d\n",pre(x,root));
else printf("%d\n",neg(x,root));
}
return 0;
}