暑假学习日记:Splay树

    从昨天开始我就想学这个伸展树了,今天花了一个上午2个多小时加下午2个多小时,学习了一下伸展树(Splay树),学习的时候主要是看别人博客啦~发现下面这个博客挺不错的http://zakir.is-programmer.com/posts/21871.html.在里面有连接到《运用伸展树解决数列维护问题》的文章,里面对伸展树的旋转操作讲得很仔细,而且也讲清楚了伸展树是怎么样维护一个数列的,一开始我是小白,觉得树和数列根本没什么关系,但看了之后就会明白,实际上树上的结点是维护该结点的值的,而这个值是原来数列里的哪一项呢?如果该结点对应的中序遍历数k,那么就是对应原数列a中的a[k]这一项.理解了这个之后我就豁然开朗了,要提取一个区间[a,b],实际上只需要将a-1,Splay为根,b+1Splay到根下的右儿子,则根下的右儿子的左儿子就是[a,b]这个区间,这是由平衡树,左小右大的性质决定的.

    所以无论做什么,首先是将该区间提取出来,然后对对应结点做就好了.问题是有时a-1不存在,b+1也不存在,所以一开始人为的做两个头尾的结点.而且很多时候为了避免对NULL的特殊处理,我们会构造一个实的null,让它的sz=0;sum=0;这样就不会影响一些情况的处理

    伸展树的特性是可以反转,注意到,一棵树,如果我们将它的每个结点的左右儿子都互换一次,它的中序遍历就刚好是原来的中序遍历倒过来,利用这个性质可以实现序列反转.而且还可以添加,假如要添加一个串{b1,b2,b3,b4..}在ak之后的位置,首先调出[ak,ak+1]这个区间,然后将{b1,b2...}建一棵伸展树,然后将结点粘在root->ch[1]的左儿子上即可.删除则是同理.我我还可以提取一个区间出来,反转,再加到我想加的地方.操作都是类似的.

    伸展树的优势除了它支持上面的操作外,它还兼容线段树的add,set操作,同样也是每个结点存lazy标记就可以了,然后写一个类似线段树的pushDown,pushUp,维护好区间的信息就可以了~

    下面给出的代码很大程度上(90%)是从上面网站的代码上copy下来的,将它改成自己的习惯的变量名,然后自己多写了一个add标记,原来的代码还能求最大子段和,但加了add之后再求就有点麻烦了,所以就删掉了原本维护最大子段和的代码,自己写了个驱动程序,调了一下感觉还行.

    代码的参数设置可能会不同,像add()函数传的是从哪个位置(pos),加多少个(tot),大可直接写成l,r,传参数的姿势不同罢了,但注意的是,当要在l位置开始加的时候,传进去的是l+1,是因为前面的头指针占了一位,看到输出之后就大概明白为什么要加1了.对了,因为区间的标记的lazy的,所以直接中序遍历得不出实际的序列(因为有些标记没往下传),所以写了个maintain()先把所有标记下传,实际上是不需要的,随用随查就好了,这么写是为了方便debug~

 

#include<iostream>

#include<cstdio>

#include<cstring>

#include<string>

#include<algorithm>

#include<vector>

#define INF 0x3fffffff

#define maxn 500000

using namespace std;



struct Node

{

    Node *pre,*ch[2];

    bool rev,cov; // 结点翻转标记与cover标记

    int add; // 结点add标记(表示加了多少)

    int sz,val,sum; // 结点的size,保存的值,以及以该结点为子树的和

}*root,N[maxn],*null; // 定义了根的指针,人手写的空的指针,以及结点数组N

Node *stack[maxn]; // 用一个栈来回收用过的指针,这是学到的新姿势,这样的话在构造新的结点的时候可以不用一直idx++

int top,idx; // 栈顶指针,以及数组idx指针

int a[maxn+20]; // 用来构造伸展树的数组



Node *addNode(int val) // 产生新结点

{

    Node *p;

    if(top) p=stack[--top]; // 首先从回收栈里取

    else p=&N[idx++]; // 没有的话从N里面取

    //初始化

    p->rev=p->cov=false; 

    p->sz=1;

    p->sum=p->val=val;

    p->ch[0]=p->ch[1]=p->pre=null;

    return p;

}



void Recycle(Node *p) // 递归回收删除掉的指针,这是用来节省空间的

{

    if(p->ch[0]!=null) Recycle(p->ch[0]);

    if(p->ch[1]!=null) Recycle(p->ch[1]);

    stack[++top]=p;

}



void pushDown(Node *p) // 核心函数,用来处理标记的

{

    if(p==null||!p) return; // 遇到空指针返回

    if(p->rev)  // 先处理反转标记

    {

        swap(p->ch[0],p->ch[1]); // 交换子树

        if(p->ch[0]!=null) p->ch[0]->rev^=1; // 标记下传

        if(p->ch[1]!=null) p->ch[1]->rev^=1; // 标记下传

        p->rev=false; // 标记取消

    }

    if(p->cov) //下面的cov和add标记的更新与下传与线段树相同

    {

        if(p->ch[0]!=null){

            p->ch[0]->val=p->val;

            p->ch[0]->sum=p->val*p->ch[0]->sz;

            p->ch[0]->cov=true;

            p->ch[0]->add=0;

        }

        if(p->ch[1]!=null){

            p->ch[1]->val=p->val;

            p->ch[1]->sum=p->val*p->ch[1]->sz;

            p->ch[1]->cov=true;

            p->ch[1]->add=0;

        }

        p->cov=false;

    }

    if(p->add)

    {

        if(p->ch[0]!=null){

            p->ch[0]->val+=p->add;

            p->ch[0]->sum+=p->ch[0]->sz*p->add;

            p->ch[0]->add+=p->add;

        }

        if(p->ch[1]!=null){

            p->ch[1]->val+=p->add;

            p->ch[1]->sum+=p->ch[1]->sz*p->add;

            p->ch[1]->add+=p->add;

        }

        p->add=0;

    }

}



void pushUp(Node *p) // 核心函数,维护信息

{

    if(p==null) return;

    pushDown(p); 

    pushDown(p->ch[0]);

    pushDown(p->ch[1]);

    p->sz=p->ch[0]->sz+p->ch[1]->sz+1;

    p->sum=p->val+p->ch[0]->sum+p->ch[1]->sum;

}



void rotate(Node *x,int c) // Splay树的旋转函数,标准姿势

{

    Node *y=x->pre;

    pushDown(y);pushDown(x);

    y->ch[c^1]=x->ch[c];

    if(x->ch[c]!=null)

        x->ch[c]->pre=y;

    x->pre=y->pre;

    if(y->pre!=null)

        if(y->pre->ch[0]==y)

            y->pre->ch[0]=x;

        else

            y->pre->ch[1]=x;

    x->ch[c]=y;y->pre=x;

    if(y==root) root=x;

    pushUp(y);

}



void Splay(Node *x,Node *f) // 将x结点转到f下

{

    pushDown(x);

    while(x->pre!=f)

    {

        Node *y=x->pre,*z=y->pre;

        if(x->pre->pre==f)

            rotate(x,x->pre->ch[0]==x);

        else

        {

            if(z->ch[0]==y){

                if(y->ch[0]==x) {rotate(y,1);rotate(x,1);}

                else {rotate(x,0);rotate(x,1);}

            }

            else{

                if(y->ch[1]==x) {rotate(y,0),rotate(x,0);}

                else {rotate(x,1),rotate(x,0);}

            }

        }

    }

    pushUp(x);

}



Node *select(int kth) // 选出第k个点,返回对应结点

{

    int tmp;

    Node *t=root;

    while(1){

        pushDown(t);

        tmp=t->ch[0]->sz;

        if(tmp+1==kth) break;

        if(kth<=tmp) {t=t->ch[0];}

        else { kth-=tmp+1;t=t->ch[1];}

    }

    return t;

}



Node *build(int L,int R) // 建树,有点像线段树

{

    if(L>R) return null;

    int M=(L+R)>>1;

    Node *p=addNode(a[M]);

    p->ch[0]=build(L,M-1);

    if(p->ch[0]!=null){

        p->ch[0]->pre=p;

    }

    p->ch[1]=build(M+1,R);

    if(p->ch[1]!=null){

        p->ch[1]->pre=p;

    }

    pushUp(p);

}



void remove(int pos,int tot) // 从pos位置开始,删除tot个(包括pos)

{

    Splay(select(pos-1),null);

    Splay(select(pos+tot),root);

    if(root->ch[1]->ch[0]!=null){

        Recycle(root->ch[1]->ch[0]);

        root->ch[1]->ch[0]=null;

    }

    pushUp(root->ch[1]);pushUp(root);

    Splay(root->ch[1],null);

}



void insert(int pos,int tot) // 添加,插的是一个数组的时候,要在数组a里面建一颗树,即a[1~N]是要插的数

{

    Node *troot=build(1,tot);

    Splay(select(pos),null);

    Splay(select(pos+1),root);

    root->ch[1]->ch[0]=troot;

    troot->pre=root->ch[1];

    pushUp(root->ch[1]);pushUp(root);

    Splay(troot,null); 

}



void reverse(int pos,int tot) // 从pos开始翻转tot个

{

    Splay(select(pos-1),null);

    Splay(select(pos+tot),root);

    if(root->ch[1]->ch[0]!=null)

    {

        root->ch[1]->ch[0]->rev^=1;

        Splay(root->ch[1]->ch[0],null);

    }

}



void set(int pos,int tot,int c) // 从pos开始将tot个设置为c

{

    Splay(select(pos-1),null);

    Splay(select(pos+tot),root);

    root->ch[1]->ch[0]->val=c;

    root->ch[1]->ch[0]->sum=root->ch[1]->ch[0]->sz*c;

    root->ch[1]->ch[0]->cov=true;

    Splay(root->ch[1]->ch[0],null);

}



void add(int pos,int tot,int c) // 从pos开始将tot个加c

{

    Splay(select(pos-1),null);

    Splay(select(pos+tot),root);

    root->ch[1]->ch[0]->val+=c;

    root->ch[1]->ch[0]->sum+=c*root->ch[1]->ch[0]->sz;

    root->ch[1]->ch[0]->add+=c;

    Splay(root->ch[1]->ch[0],null);

}



int query(int pos,int tot) // 求pos开始tot个的和

{

    Splay(select(pos-1),null);

    Splay(select(pos+tot),root);

    return root->ch[1]->ch[0]->sum;

}



void init() // 初始化函数

{

    idx=top=0; // idx,top归零

    null=addNode(-INF); // 初始化空指针

    null->sz=null->sum=0; // 记住sz和sum一定要设为0

    root=addNode(-INF); // 初始化根指针

    root->sum=0;

    Node *p;

    p=addNode(-INF); // 初始化"树尾"的指针

    root->ch[1]=p;

    p->pre=root;

    p->sum=0;

    pushUp(root->ch[1]);

    pushUp(root);

}

//下面三个函数是调试的时候用的

void maintain(Node *p) // 因为标记是lazy的,所以先将所有标记都下传好

{

    pushDown(p);

    if(p->ch[0]!=null) maintain(p->ch[0]);

    if(p->ch[1]!=null) maintain(p->ch[1]);

}

void dfs(Node *x) // 中序遍历

{

    if(x==null) return;

    dfs(x->ch[0]);

    printf("%d ",x->val);

    dfs(x->ch[1]);

}

void print() // 打印

{

    maintain(root); 

    dfs(root);

    puts("");

}



int main()

{

    int n,m;

    while(cin>>n)

    {

        for(int i=1;i<=n;i++){

            scanf("%d",&a[i]);

        }

        init();

        Node *troot=build(1,n); // 从a数组建一颗Splay树

        root->ch[1]->ch[0]=troot; // 让它和init()里的root,p连上

        troot->pre=root->ch[1]; 

        pushUp(root->ch[1]); // 维护相关信息

        pushUp(root->ch[0]);

        cin>>m;

        int o,l,r,v;

        //支持六种操作,区间add,区间set,区间反转,区间删除,区间添加,区间和

        while(m--)

        {

            scanf("%d",&o);

            if(o==1){

                scanf("%d%d%d",&l,&r,&v);add(l+1,r-l+1,v);print();

            }

            else if(o==2){

                scanf("%d%d%d",&l,&r,&v);set(l+1,r-l+1,v);print();

            }

            else if(o==3){

                scanf("%d%d",&l,&r);reverse(l+1,r-l+1);print();

            }

            else if(o==4){

                scanf("%d%d",&l,&r);remove(l+1,r-l+1);print();

            }

            else if(o==5){

                scanf("%d%d",&l,&v);

                for(int i=1;i<=v;i++){ scanf("%d",&a[i]);}

                insert(l+1,v);

                print();

            }

            else if(o==6){

                scanf("%d%d",&l,&r);

                cout<<query(l+1,r-l+1)<<endl;

            }

        }

    }

    return 0;

}

 

 

 

你可能感兴趣的:(play)