SplayTree实现

template<class DataType>
class SplayTree{
// NOTE(review): this macro leaks out of the class; kept because code using the class may rely on `null`.
#define null 0
private:
    // Every node is a slot index into parallel arrays.
    // Slot 0 is the null sentinel; real nodes occupy slots 1 .. capacity-1.
    int capacity;                   // current length of every per-node array
    int *l,*r,*p,*pool,*depth;      // left child, right child, parent, recycled slots, debug depth
    int top,root,tot;               // recycled-slot count, root slot, highest slot ever issued
    DataType *key;                  // key stored in each slot

    // The class owns raw arrays, so the compiler-generated copy would double-delete them.
    // Declared but intentionally never defined: copying the tree is an error.
    SplayTree(const SplayTree &);
    SplayTree &operator=(const SplayTree &);

    /** Replace arr (oldSize elements) with a value-initialised array of newSize elements, keeping the old prefix. */
    template<class T>
    static void grow(T *&arr, int oldSize, int newSize){
        T *bigger = new T[newSize]();
        for(int i = 0; i < oldSize; ++i) bigger[i] = arr[i];
        delete[] arr;
        arr = bigger;
    }
    /** Double the length of every per-node array. Invalidates pointers/references into them. */
    void expand(){
        int newCapacity = capacity * 2;
        grow(l, capacity, newCapacity);
        grow(r, capacity, newCapacity);
        grow(p, capacity, newCapacity);
        grow(pool, capacity, newCapacity);
        grow(depth, capacity, newCapacity);
        grow(key, capacity, newCapacity);
        capacity = newCapacity;
    }
    /** Take a free slot (a recycled one first, otherwise a fresh one), store k in it and return it. */
    int allocNode(const DataType &k){
        int node;
        if(top != 0){
            node = pool[--top];
        }else{
            // Slot tot+1 must exist before it is handed out.
            if(tot + 1 >= capacity) expand();
            node = ++tot;
        }
        key[node] = k;
        l[node] = r[node] = p[node] = null;
        return node;
    }
    /** Return a slot to the free pool with its links cleared so it can be reused. */
    void releaseNode(int node){
        l[node] = r[node] = p[node] = null;
        pool[top++] = node;
    }
    /** Right rotation: x is the left child of p[x]; x takes its parent's place. */
    void zig(int x){
        int fa = p[x];
        int ga = p[fa];
        l[fa] = r[x];
        if(l[fa] != null) p[l[fa]] = fa;   // never write the sentinel's parent slot
        r[x] = fa;
        p[fa] = x;
        if(ga != null){
            if(l[ga] == fa) l[ga] = x;
            else r[ga] = x;
        }
        p[x] = ga;
    }
    /** Left rotation: x is the right child of p[x]; x takes its parent's place. */
    void zag(int x){
        int fa = p[x];
        int ga = p[fa];
        r[fa] = l[x];
        if(r[fa] != null) p[r[fa]] = fa;   // never write the sentinel's parent slot
        l[x] = fa;
        p[fa] = x;
        if(ga != null){
            if(l[ga] == fa) l[ga] = x;
            else r[ga] = x;
        }
        p[x] = ga;
    }
    /**
     * Splay x to the top of the tree containing it (until x has no parent),
     * then store x in rt, which must name that tree's root.
     */
    void splay(int x,int &rt){
        while(p[x] != null){
            int fa = p[x], ga = p[fa];
            if(ga == null){
                if(l[fa] == x) zig(x); else zag(x);   // single rotation at the root
            }else if(l[ga] == fa){
                if(l[fa] == x){ zig(fa); zig(x); }    // zig-zig
                else          { zag(x);  zig(x); }    // zig-zag
            }else{
                if(r[fa] == x){ zag(fa); zag(x); }    // zag-zag
                else          { zig(x);  zag(x); }    // zag-zig
            }
        }
        rt = x;
    }
    /**
     * Iteratively search the subtree rooted at rt for goal.
     * Returns the matching slot or null. `last` receives the final node visited
     * (null if the subtree is empty) so callers can splay it on a miss.
     * Iterative because a splay tree may legitimately be a long chain.
     */
    int find_help(const DataType &goal,int rt,int &last){
        last = null;
        for(int node = rt; node != null; node = goal < key[node] ? l[node] : r[node]){
            last = node;
            if(key[node] == goal) return node;
        }
        return null;
    }
    /**
     * Attach a new leaf holding goal below rt (equal keys go right) and return its slot.
     * rt must not refer into the node arrays, because allocation may reallocate them.
     */
    int insert_help(const DataType &goal,int &rt){
        int node = allocNode(goal);   // allocate before touching l/r: it may grow the arrays
        if(rt == null){
            rt = node;                // empty tree: the new node is the root, parent stays null
            return node;
        }
        int cur = rt;
        while(true){
            int &next = goal < key[cur] ? l[cur] : r[cur];
            if(next == null){
                next = node;
                break;
            }
            cur = next;
        }
        p[node] = cur;
        return node;
    }
    /** Rightmost node of the subtree rooted at rt, or null for an empty subtree. */
    int findmax_help(int rt){
        int node = rt;
        while(node != null && r[node] != null) node = r[node];
        return node;
    }
    /** Leftmost node of the subtree rooted at rt, or null for an empty subtree. */
    int findmin_help(int rt){
        int node = rt;
        while(node != null && l[node] != null) node = l[node];
        return node;
    }
    /**
     * Merge the BSTs rooted at rt1 and rt2 into one BST and return its root.
     * Requires every key in rt1 to be <= every key in rt2 (as produced by split):
     * the maximum of rt1 is splayed to its root and rt2 becomes its right subtree.
     */
    int join(int &rt1,int &rt2){
        if(rt1 == null) return rt2;
        if(rt2 == null) return rt1;
        splay(findmax_help(rt1), rt1);
        r[rt1] = rt2;
        p[rt2] = rt1;
        return rt1;
    }
    /**
     * Split the tree rooted at rt around a node whose key equals goal.
     * On a hit, that node is splayed to rt and detached. Its former left and right
     * subtrees are returned in rt1/rt2 as parentless trees, and the node's slot is returned.
     * On a miss, the last visited node is splayed (to keep amortised bounds),
     * rt1 = rt2 = null, and null is returned.
     */
    int split(const DataType &goal,int &rt,int &rt1,int &rt2){
        int last;
        int node = find_help(goal, rt, last);
        if(node == null){
            if(last != null) splay(last, rt);
            rt1 = rt2 = null;
            return null;
        }
        splay(node, rt);
        rt1 = l[rt];
        rt2 = r[rt];
        if(rt1 != null) p[rt1] = null;
        if(rt2 != null) p[rt2] = null;
        l[rt] = r[rt] = null;
        return node;
    }
    /** Store each node's depth (root = dep) in depth[] and return the deepest level seen (0 if empty). */
    int refreshDepth(int rt,int dep){
        if(rt == null) return 0;
        depth[rt] = dep;
        int t1 = refreshDepth(l[rt], dep + 1);
        int t2 = refreshDepth(r[rt], dep + 1);
        int deepest = t1 > t2 ? t1 : t2;
        return deepest > dep ? deepest : dep;
    }
public:
    /** Create an empty tree with room for maxsize nodes; storage grows automatically beyond that. */
    SplayTree(int maxsize){
        // +1 because slot 0 is the null sentinel; at least one real slot is always reserved.
        capacity = (maxsize > 0 ? maxsize : 1) + 1;
        depth = new int[capacity]();
        l = new int[capacity]();
        r = new int[capacity]();
        p = new int[capacity]();
        pool = new int[capacity]();
        // Value-initialise instead of memset: memset is undefined for non-trivial DataType (e.g. std::string).
        key = new DataType[capacity]();
        top = root = tot = 0;
    }
    ~SplayTree(){
        delete[] depth;
        delete[] l;
        delete[] r;
        delete[] p;
        delete[] pool;
        delete[] key;
    }
    /** Return 1 if goal is present, else 0. The found node (or the last node visited on a miss) is splayed to the root. */
    int find(const DataType &goal){
        int last;
        int node = find_help(goal, root, last);
        int target = node != null ? node : last;
        if(target != null) splay(target, root);
        return node != null ? 1 : 0;
    }
    /** Splay the maximum to the root and return its slot (null if empty); read it with getValue. */
    int findmax(){
        int node = findmax_help(root);
        if(node != null) splay(node, root);
        return node;
    }
    /** Splay the minimum to the root and return its slot (null if empty); read it with getValue. */
    int findmin(){
        int node = findmin_help(root);
        if(node != null) splay(node, root);
        return node;
    }
    /** Insert goal (duplicates allowed) and splay the new node to the root. */
    void insert(const DataType &goal){
        int node = insert_help(goal, root);
        splay(node, root);
    }
    /**
     * Remove one node whose key equals goal.
     * Returns its former slot (non-zero) on success, or null if absent.
     * The returned slot has already been recycled, so do not read it with getValue.
     */
    int remove(const DataType &goal){
        int rt1 = null, rt2 = null;
        int node = split(goal, root, rt1, rt2);
        if(node != null){
            releaseNode(node);   // split left node at the root
            root = join(rt1, rt2);
        }
        return node;
    }
    /** Key stored in slot node (no bounds check). */
    DataType getValue(int node){
        return key[node];
    }
    /* Debug: print the tree height, then each level's keys; empty positions print key[null]. */
    void bfs(){
        int termi = refreshDepth(root,1);
        std::cout<<termi<<std::endl;
        std::queue<int> Q;
        Q.push(root);
        int flag = 1,dep = 1,cond = dep;
        while(1){
            int now = Q.front(); Q.pop();

            if(cond==0){
                dep++;
                flag = flag*2;
                cond = flag;
                std::cout<<std::endl;
            }
            if(dep>termi)break;
            std::cout<<"["<<key[now]<<"]"; cond--;
            Q.push(l[now]);
            Q.push(r[now]);
        }
    }
};

你可能感兴趣的:(tree)