【模板】Splay

题目链接:
洛谷 P3369 【模板】普通平衡树(Treap/SBT)
BZOJ 3224: Tyvj 1728 普通平衡树


第一次尝试

第一次splay板子是大佬教给我的,全部用指针完成了splay的基本操作。当时的我码力还是不足,调试了半天极其低级的错误,当时的我还把他们记载在下面:

  • [attempt 1] CE 由于构造函数未在结构体里声明
struct node
{
    int val,siz,cnt;
    node *son[2],*fa;
    node(const int &k);     //<-this line.
    int dir(){return fa->son[1]==this;}
    void upd(){siz=son[0]->siz+son[1]->siz+cnt;}
}*nil=new node(0),*RT,*flag;
node::node(const int &k)
{
    val=k,siz=cnt=1,son[0]=son[1]=fa=nil;
}
  • [attempt 2] WA rotate函数错误
void rotate(node *rt,int d)
{
    node *t=rt->son[d^1];
    rt->son[d^1]=t->son[d];
    if(rt->son[d^1]!=nil)rt->son[d^1]->fa=rt;
    t->son[d]=rt;    //<-this line.
    if(rt->fa!=nil)rt->fa->son[rt->dir()]=t;
    t->fa=rt->fa;rt->fa=t;
    rt->upd();t->upd();
    return;
}
  • [attempt 3] WA delete函数错误
void del(int x)
{
    node *l=lower(RT,x),*r=upper(RT,x);
    if(l==nil&&r==nil)
    {
        if(RT->cnt==1)RT=nil;
        else RT->cnt--,RT->siz--;
        return;
    }
    if(l==nil&&r!=nil)
    {
        splay(r,nil);
        if(RT->son[0]->cnt==1)RT->son[0]=nil;
        else RT->son[0]->cnt--,RT->son[0]->siz--;
        RT->upd();
        return;
    }
    if(r==nil)
    {
        splay(l,nil);
        if(RT->son[1]->cnt==1)RT->son[1]=nil;
        else RT->son[1]->cnt--,RT->son[1]->siz--;
        RT->upd();
        return;
    }
    splay(l,nil);
    splay(r,RT);
    node *obj=RT->son[1]->son[0];
    if(obj->cnt==1)RT->son[1]->son[0]=nil;
    else obj->cnt--,obj->siz--;
    RT->son[1]->upd();
    RT->upd();
    return;
}

这里我调试了很久,比如虚拟节点不能update更新子树大小,再比如删除时只更新了临时变量,没有更改其父亲的儿子指针等等。


第二次尝试

这一次的我经历了一年多的沉淀然而还是那么菜,选择了用数组完成splay这个数据结构。

#include
#include
#include
using namespace std;
inline int read(){
  int x=0,y=1;char c=getchar();
  while(!isdigit(c)){if(c=='-')y=-y;c=getchar();}
  while(isdigit(c)){x=x*10+c-'0';c=getchar();}
  return x*y;
}

const int INF = 100000000;
const int MAXN = 100010;
int son[MAXN][2],fa[MAXN];
int siz[MAXN],rev[MAXN],cnt[MAXN],key[MAXN];
int tot,root;

int getd(int x){return son[fa[x]][1] == x;}

int push_up(int x){
  siz[x] = siz[son[x][0]] + siz[son[x][1]] + cnt[x];
  return 0;
}
void push_down(int x){
  if(rev[x]){
    rev[x]^=1,rev[son[x][0]]^=1,rev[son[x][1]]^=1;
    swap(son[x][0],son[x][1]);
  }
}

void rotate(int x){
  int fat=fa[x],gra=fa[fat],dir=getd(x),fad=getd(fat);
  son[fat][dir]=son[x][dir^1];
  fa[son[x][dir^1]]=fat;
  son[x][dir^1]=fat;
  fa[fat]=x;
  if(gra)son[gra][fad]=x;
  fa[x]=gra;
  push_up(fat);
  push_up(x);
}
int splay(int x,int to){
  while(fa[x]!=to){
    if(fa[fa[x]]!=to&&getd(fa[x])==getd(x))rotate(fa[x]);
    rotate(x);
  }
  if(to==0)root=x;
  return 0;
}

void init(){
  key[1]=-INF,key[2]=INF;
  siz[1]=1,siz[2]=2;
  cnt[1]=cnt[2]=1;
  fa[1]=2,son[2][0]=1;
  root=tot=2;
}

int insert(int now,int fat,int dir,int k){
  if(now==0){
    fa[++tot]=fat,son[fat][dir]=tot;
    cnt[tot]=siz[tot]=1;
    key[tot]=k;
    return tot;
  }
  if(key[now]==k){
    siz[now]++;
    cnt[now]++;
    return now;
  }
  if(key[now]>k)return insert(son[now][0],now,0,k)+push_up(now);
  if(key[now]return insert(son[now][1],now,1,k)+push_up(now);
  return 0;
}

int Find(int now,int x){
  if(now==0)return 0;
  if(key[now]==x)return now;
  if(key[now]>x)return Find(son[now][0],x);
  if(key[now]return Find(son[now][1],x);
}

int get_rank(int now,int k){
  if(now==0)return printf("error\n");
  if(key[now]==k)return siz[son[now][0]]+1;
  if(key[now]>k)return get_rank(son[now][0],k);
  if(key[now]return get_rank(son[now][1],k)+siz[now]-siz[son[now][1]];
  return 0;
}

int find_rank(int now,int k){
  if(now==0)return 0;
  const int L=siz[son[now][0]],N=cnt[now],R=siz[son[now][1]];
  if(k<=L)return find_rank(son[now][0],k);
  if(L0);return key[now];}
  if(k>L+N)return find_rank(son[now][1],k-L-N);
}

inline int max_key(int x,int y){key[0]=-INF;return key[x]>key[y]?x:y;}
int prev(int now,int k){
  if(now==0)return 0;
  if(key[now]>=k)return prev(son[now][0],k);
  return max_key(prev(son[now][1],k),now);
}

inline int min_key(int x,int y){key[0]=INF;return key[x]int next(int now,int k){
  if(now==0)return 0;
  if(key[now]<=k)return next(son[now][1],k);
  return min_key(next(son[now][0],k),now);
}

int Delete(int k){
  splay(prev(root,k),0);
  splay(next(root,k),root);
  int t=son[son[root][1]][0];
  if(cnt[t]>1){cnt[t]--;siz[t]--;}
  else son[son[root][1]][0]=0;
  push_up(son[root][1]);
  push_up(root);
}

int main(){
  init();
  int n=read();
  while(n--){
    int opt=read(),x=read();
    switch(opt){
      case 1:{
        splay(insert(root,0,0,x),0);
        break;
      }
      case 2:{
        Delete(x);
        break;
      }
      case 3:{
        printf("%d\n",get_rank(root,x)-1);
        splay(Find(root,x),0);
        break;
      }
      case 4:{
        printf("%d\n",find_rank(root,x+1));
        break;
      }
      case 5:{
        printf("%d\n",key[prev(root,x)]);
        break;
      }
      case 6:{
        printf("%d\n",key[next(root,x)]);
        break;
      }
    }
    //for(int i=1;find_rank(root,i)!=INF;i++){
    //  printf("%d ",find_rank(root,i));
    //}
    //printf("\n");
  }
}

关于splay的删除,网上的版本不一,但经过我的搜索,最快的应该是下面这种:

【模板】Splay_第1张图片

可以说是非常稳了……

【模板】Splay_第2张图片


代码如下:

#include
using namespace std;
struct node
{
    int val,siz,cnt;
    node *son[2],*fa;
    node(const int &k);
    int dir(){return fa->son[1]==this;}
    void upd(){siz=son[0]->siz+son[1]->siz+cnt;}
}*nil=new node(0),*RT,*flag;
node::node(const int &k)
{
    val=k,siz=cnt=1,son[0]=son[1]=fa=nil;
}
void clear()
{
    nil->siz=nil->cnt=0,RT=nil;
    return;
}
void rotate(node *rt,int d)
{
    node *t=rt->son[d^1];
    rt->son[d^1]=t->son[d];
    if(rt->son[d^1]!=nil)rt->son[d^1]->fa=rt;
    t->son[d]=rt;
    if(rt->fa!=nil)rt->fa->son[rt->dir()]=t;
    t->fa=rt->fa;rt->fa=t;
    rt->upd();t->upd();
    return;
}
void splay(node *rt,node *to)
{
    while(rt->fa!=to)
    {
        if(rt->fa->fa!=to&&rt->dir()==rt->fa->dir())
            rotate(rt->fa->fa,rt->dir()^1);
        rotate(rt->fa,rt->dir()^1);
    }
    if(to==nil)RT=rt;
    return;
}
void add(node *&rt,node *fa,int x)
{
    if(rt==nil)
    {
        rt=new node(x);
        rt->fa=fa;
        flag=rt;
        return;
    }
    if(rt->val==x){rt->cnt++;flag=rt;}
    else if(rt->val>x){add(rt->son[0],rt,x);}
    else {add(rt->son[1],rt,x);}
    rt->upd();
    return;
}
void insert(int x)
{
    add(RT,nil,x);
    splay(flag,nil);
    return;
}
node *lower(node *rt,int x)
{
    if(rt==nil)return nil;
    if(rt->val>=x)return lower(rt->son[0],x);
    else{
        node *t=lower(rt->son[1],x);
        return t==nil?rt:t;
    }
}
node *upper(node *rt,int x)
{
    if(rt==nil)return nil;
    if(rt->val<=x)return upper(rt->son[1],x);
    else
    {
        node *t=upper(rt->son[0],x);
        return t==nil?rt:t;
    }
}
void find(node *rt,int x){
    if(rt==nil)return;
    if(rt->val==x){flag=rt;return;}
    if(rt->val>x)find(rt->son[0],x);
    else find(rt->son[1],x);
    return;
}
void del(int x)
{
    node *l=lower(RT,x),*r=upper(RT,x);
    if(l==nil&&r==nil)
    {
        if(RT->cnt==1)RT=nil;
        else RT->cnt--,RT->siz--;
        return;
    }
    if(l==nil&&r!=nil)
    {
        splay(r,nil);
        if(RT->son[0]->cnt==1)RT->son[0]=nil;
        else RT->son[0]->cnt--,RT->son[0]->siz--;
        RT->upd();
        return;
    }
    if(r==nil)
    {
        splay(l,nil);
        if(RT->son[1]->cnt==1)RT->son[1]=nil;
        else RT->son[1]->cnt--,RT->son[1]->siz--;
        RT->upd();
        return;
    }
    splay(l,nil);
    splay(r,RT);
    node *obj=RT->son[1]->son[0];
    if(obj->cnt==1)RT->son[1]->son[0]=nil;
    else obj->cnt--,obj->siz--;
    RT->son[1]->upd();
    RT->upd();
    return;
}
int rand(int x)
{
    node *rt=RT;int res=1;
    while(rt->val!=x)
    {
        if(rt->val>x)rt=rt->son[0];
        else res=res+rt->son[0]->siz+rt->cnt,rt=rt->son[1];
    }
    res+=rt->son[0]->siz;
    return res;
}
int rand(node *rt,int x)
{
    if(rt->son[0]->siz>=x)return rand(rt->son[0],x);
    if(rt->son[0]->siz+rt->cnt>=x)return rt->val;
    return rand(rt->son[1],x-rt->cnt-rt->son[0]->siz);
}
int read()
{
    int x=0,y=1;char c=getchar();
    while(!isdigit(c)){if(c=='-')y=-y;c=getchar();}
    while(isdigit(c))x=x*10+c-'0',c=getchar();
    return x*y;
}
int main()
{
    clear();
    int n=read();
    while(n--)
    {
        int opt=read(),x=read();
        switch(opt)
        {
            case 1:
            {
                insert(x);
                break;
            }
            case 2:
            {
                del(x);
                break;
            }
            case 3:
            {
                printf("%d\n",rand(x));
                break;
            }
            case 4:
            {
                printf("%d\n",rand(RT,x));
                break;
            }
            case 5:
            {
                printf("%d\n",lower(RT,x)->val);
                break;
            }
            case 6:
            {
                printf("%d\n",upper(RT,x)->val);
                break;
            }
        }
    }
    return 0;
}
/*
10
1 1
4 1
1 3
1 4
*/

你可能感兴趣的:(数据结构,平衡树,splay)