SplayTree 实现（简化版）：(1) 删除时先把目标结点伸展到根，再拆分并合并其左右子树；(2) 按普通二叉搜索树的方式删除结点，再做伸展。

(1)

// Splay tree over a fixed-size node pool -- variant 1.
//
// An ordered multiset: equal keys may be inserted more than once (they are
// placed in the right subtree on insertion). Deletion splays the target node
// to the root, detaches its two subtrees and joins them back together.
//
// Nodes are array indices and index 0 is the null sentinel, so at most
// MaxNodes - 1 elements can be stored at once; freed indices are recycled.
// DataType must be default-constructible, copy-assignable, and provide
// operator< and operator==.
//
// Every slot costs four ints plus one DataType (roughly 20 MB for int with the
// default capacity), so large instances belong in static storage or on the
// heap, not on the stack.
template<class DataType, int MaxNodes = 1000010>
class SplayTree{
private:
    int ch[MaxNodes][2];     // ch[x][0] / ch[x][1]: left / right child of x (0 = none)
    int pre[MaxNodes];       // parent of x (0 = x is the root of its tree)
    int pool[MaxNodes];      // stack of freed indices waiting to be reused
    DataType key[MaxNodes];  // key stored in node x
    int top;                 // number of indices on the free stack
    int root;                // root of the tree (0 = empty tree)
    int tot;                 // highest index handed out so far

    // Returns a node index (recycled first, otherwise fresh) initialised as a
    // detached node holding dt, or 0 if every slot is in use.
    int alloc_node(const DataType& dt){
        int x;
        if(top!=0) x = pool[--top];
        else if(tot+1<MaxNodes) x = ++tot;
        else return 0;
        key[x] = dt;
        // Reset links here: for non-static instances the arrays start out
        // indeterminate, so a fresh slot cannot be assumed to be zeroed.
        ch[x][0] = ch[x][1] = pre[x] = 0;
        return x;
    }
    // Puts index x back on the free stack.
    void free_node(int x){
        ch[x][0] = ch[x][1] = pre[x] = 0;
        pool[top++] = x;
    }
    // Rotates x above its parent y.
    // f == 1: x is y's left child  -> right rotation ("zig").
    // f == 0: x is y's right child -> left rotation ("zag").
    void rotate(int x,int f){
        int y = pre[x];
        int z = pre[y];
        ch[y][!f] = ch[x][f];              // x's inner subtree moves under y
        if(ch[y][!f]) pre[ch[y][!f]] = y;
        ch[x][f] = y;
        pre[y] = x;
        if(z){                             // x takes y's old place under z
            if(ch[z][0]==y) ch[z][0] = x;
            else ch[z][1] = x;
        }
        pre[x] = z;
    }
    // Splays x up to the root of the tree containing it (until pre[x] == 0)
    // and stores x in rt.
    void splay(int x,int &rt){
        while(pre[x]){
            int y = pre[x];
            int z = pre[y];
            if(!z){
                rotate(x,ch[y][0]==x);          // zig: parent is the root
            }else{
                int f = ch[z][0]==y;            // 1 if y is z's left child
                if(ch[y][!f]==x){               // zig-zig: x and y on the same side
                    rotate(y,f); rotate(x,f);
                }else{                          // zig-zag: opposite sides
                    rotate(x,!f); rotate(x,f);
                }
            }
        }
        rt = x;
    }
    // BST search below rt for a node whose key equals k. Iterative, so a
    // path-shaped tree cannot overflow the call stack. Returns that node or 0;
    // `last` receives the last node visited (0 if rt is empty) so callers can
    // splay it after an unsuccessful search.
    int find_help(const DataType& k,int rt,int &last){
        last = 0;
        for(int x = rt; x; x = ch[x][!(k<key[x])]){
            last = x;
            if(k==key[x]) return x;
        }
        return 0;
    }
    // Rightmost (largest) node below rt, or 0 if rt is empty.
    int findmax_help(int rt){
        int x = rt;
        while(x&&ch[x][1]) x = ch[x][1];
        return x;
    }
    // Leftmost (smallest) node below rt, or 0 if rt is empty.
    int findmin_help(int rt){
        int x = rt;
        while(x&&ch[x][0]) x = ch[x][0];
        return x;
    }
    // Joins two detached trees where every key in rt1 is <= every key in rt2:
    // the maximum of rt1 is splayed to its root (leaving it without a right
    // child) and rt2 is hung there. Returns the root of the combined tree.
    int join(int rt1,int rt2){
        if(!rt1) return rt2;
        if(!rt2) return rt1;
        splay(findmax_help(rt1),rt1);
        ch[rt1][1] = rt2;
        pre[rt2] = rt1;
        return rt1;
    }
    // Looks for a node with key k in the tree rooted at rt. If found, splays it
    // to the root, stores its left/right subtrees (now parentless) in rt1/rt2
    // and returns it. Otherwise splays the last node visited, so repeated
    // failed searches do not keep walking the same deep path, and returns 0
    // without touching rt1/rt2.
    int split(const DataType& k,int &rt,int &rt1,int &rt2){
        int last;
        int x = find_help(k,rt,last);
        if(!x){
            if(last) splay(last,rt);
            return 0;
        }
        splay(x,rt);
        rt1 = ch[x][0];
        rt2 = ch[x][1];
        if(rt1) pre[rt1] = 0;
        if(rt2) pre[rt2] = 0;
        return x;
    }
    // Links a new leaf holding k into the tree rooted at rt by an ordinary BST
    // descent (equal keys go right). Iterative for the same stack-depth reason
    // as find_help. Returns the new node, or 0 if the pool is exhausted (the
    // tree is then unchanged).
    int insert_help(const DataType& k,int &rt){
        int x = alloc_node(k);
        if(!x) return 0;
        int father = 0;
        int *slot = &rt;
        while(*slot){
            father = *slot;
            slot = &ch[father][!(k<key[father])];
        }
        *slot = x;
        pre[x] = father;
        return x;
    }
public:
    SplayTree():top(0),root(0),tot(0){}

    // Inserts k (duplicates allowed) and splays it to the root.
    // Returns false, leaving the tree unchanged, if MaxNodes - 1 elements are
    // already stored.
    bool insert(const DataType& k){
        int x = insert_help(k,root);
        if(!x) return false;
        splay(x,root);
        return true;
    }
    // Removes one element equal to k; does nothing if there is none.
    void remove(const DataType& k){
        int rt1 = 0, rt2 = 0;
        int x = split(k,root,rt1,rt2);
        if(!x) return;
        free_node(x);
        root = join(rt1,rt2);
    }
    // Stores the largest key in k, splays its node to the root and returns the
    // node index (non-zero). Returns 0 and leaves k untouched if the tree is empty.
    int findmax(DataType &k){
        int x = findmax_help(root);
        if(!x) return 0;
        splay(x,root);
        k = key[x];
        return x;
    }
    // Same as findmax, but for the smallest key.
    int findmin(DataType &k){
        int x = findmin_help(root);
        if(!x) return 0;
        splay(x,root);
        k = key[x];
        return x;
    }
};

(2)

// Splay tree over a fixed-size node pool -- variant 2.
//
// An ordered multiset: equal keys may be inserted more than once (they are
// placed in the right subtree on insertion). Deletion is ordinary BST
// deletion (a node with two children takes its in-order successor's key and
// the successor is unlinked instead), followed by a splay.
//
// Nodes are array indices and index 0 is the null sentinel, so at most
// MaxNodes - 1 elements can be stored at once; freed indices are recycled.
// DataType must be default-constructible, copy-assignable, and provide
// operator< and operator==.
//
// Every slot costs four ints plus one DataType (roughly 20 MB for int with the
// default capacity), so large instances belong in static storage or on the
// heap, not on the stack.
template<class DataType, int MaxNodes = 1000010>
class SplayTree{
private:
    int ch[MaxNodes][2];     // ch[x][0] / ch[x][1]: left / right child of x (0 = none)
    int pre[MaxNodes];       // parent of x (0 = x is the root)
    int pool[MaxNodes];      // stack of freed indices waiting to be reused
    DataType key[MaxNodes];  // key stored in node x
    int top;                 // number of indices on the free stack
    int root;                // root of the tree (0 = empty tree)
    int tot;                 // highest index handed out so far

    // Returns a node index (recycled first, otherwise fresh) initialised as a
    // detached node holding dt, or 0 if every slot is in use.
    int alloc_node(const DataType& dt){
        int x;
        if(top!=0) x = pool[--top];
        else if(tot+1<MaxNodes) x = ++tot;
        else return 0;
        key[x] = dt;
        // Reset links here: for non-static instances the arrays start out
        // indeterminate, so a fresh slot cannot be assumed to be zeroed.
        ch[x][0] = ch[x][1] = pre[x] = 0;
        return x;
    }
    // Puts index x back on the free stack.
    void free_node(int x){
        ch[x][0] = ch[x][1] = pre[x] = 0;
        pool[top++] = x;
    }
    // Rotates x above its parent y.
    // f == 1: x is y's left child  -> right rotation ("zig").
    // f == 0: x is y's right child -> left rotation ("zag").
    void rotate(int x,int f){
        int y = pre[x];
        int z = pre[y];
        ch[y][!f] = ch[x][f];              // x's inner subtree moves under y
        if(ch[y][!f]) pre[ch[y][!f]] = y;
        ch[x][f] = y;
        pre[y] = x;
        if(z){                             // x takes y's old place under z
            if(ch[z][0]==y) ch[z][0] = x;
            else ch[z][1] = x;
        }
        pre[x] = z;
    }
    // Splays x up to the root of the tree (until pre[x] == 0) and stores x in rt.
    void splay(int x,int &rt){
        while(pre[x]){
            int y = pre[x];
            int z = pre[y];
            if(!z){
                rotate(x,ch[y][0]==x);          // zig: parent is the root
            }else{
                int f = ch[z][0]==y;            // 1 if y is z's left child
                if(ch[y][!f]==x){               // zig-zig: x and y on the same side
                    rotate(y,f); rotate(x,f);
                }else{                          // zig-zag: opposite sides
                    rotate(x,!f); rotate(x,f);
                }
            }
        }
        rt = x;
    }
    // BST search below rt for a node whose key equals k. Iterative, so a
    // path-shaped tree cannot overflow the call stack. Returns that node or 0;
    // `last` receives the last node visited (0 if rt is empty) so callers can
    // splay it after an unsuccessful search.
    int find_help(const DataType& k,int rt,int &last){
        last = 0;
        for(int x = rt; x; x = ch[x][!(k<key[x])]){
            last = x;
            if(k==key[x]) return x;
        }
        return 0;
    }
    // Rightmost (largest) node below rt, or 0 if rt is empty.
    int findmax_help(int rt){
        int x = rt;
        while(x&&ch[x][1]) x = ch[x][1];
        return x;
    }
    // Leftmost (smallest) node below rt, or 0 if rt is empty.
    int findmin_help(int rt){
        int x = rt;
        while(x&&ch[x][0]) x = ch[x][0];
        return x;
    }
    // Links a new leaf holding k into the tree rooted at rt by an ordinary BST
    // descent (equal keys go right). Iterative for the same stack-depth reason
    // as find_help. Returns the new node, or 0 if the pool is exhausted (the
    // tree is then unchanged).
    int insert_help(const DataType& k,int &rt){
        int x = alloc_node(k);
        if(!x) return 0;
        int father = 0;
        int *slot = &rt;
        while(*slot){
            father = *slot;
            slot = &ch[father][!(k<key[father])];
        }
        *slot = x;
        pre[x] = father;
        return x;
    }
    // Unlinks node d, which must have at most one child, by splicing that
    // child (if any) into d's place, then frees d. Returns d's former parent
    // (0 if d was the root).
    int unlink_help(int d){
        int c = ch[d][0] ? ch[d][0] : ch[d][1];
        int father = pre[d];
        if(c) pre[c] = father;
        if(!father) root = c;
        else if(ch[father][0]==d) ch[father][0] = c;
        else ch[father][1] = c;
        free_node(d);
        return father;
    }
    // Removes exactly one node with key k (nothing if k is absent).
    // All structural changes finish before the single splay at the end.
    // Splaying in the middle of a recursive descent would re-point the ch[]
    // slots that outer frames still reference.
    void remove_help(const DataType& k){
        int last;
        int x = find_help(k,root,last);
        if(!x){
            if(last) splay(last,root);        // splay on a miss as well
            return;
        }
        if(ch[x][0] && ch[x][1]){
            // Two children: take the in-order successor's key; the successor
            // has no left child, so it is the node that gets unlinked.
            int succ = findmin_help(ch[x][1]);
            key[x] = key[succ];
            x = succ;
        }
        int father = unlink_help(x);
        if(father) splay(father,root);
    }
public:
    SplayTree():top(0),root(0),tot(0){}

    // Inserts k (duplicates allowed) and splays it to the root.
    // Returns false, leaving the tree unchanged, if MaxNodes - 1 elements are
    // already stored.
    bool insert(const DataType& k){
        int x = insert_help(k,root);
        if(!x) return false;
        splay(x,root);
        return true;
    }
    // Removes one element equal to k; does nothing if there is none.
    void remove(const DataType& k){
        remove_help(k);
    }
    // Stores the largest key in k, splays its node to the root and returns the
    // node index (non-zero). Returns 0 and leaves k untouched if the tree is empty.
    int findmax(DataType &k){
        int x = findmax_help(root);
        if(!x) return 0;
        splay(x,root);
        k = key[x];
        return x;
    }
    // Same as findmax, but for the smallest key.
    int findmin(DataType &k){
        int x = findmin_help(root);
        if(!x) return 0;
        splay(x,root);
        k = key[x];
        return x;
    }
};


你可能感兴趣的:(tree)