平衡树及其可持久化

可能是以前受一个指针版本treap的影响,一直是以pair格式去写treap的。
原来引用&这么方便啊。
这篇文章的代码,都是我找到的一些十分优雅的写法。

void split(int i,int x,int &a,int &b)
{
    if(!i) a=0,b=0;
    else {
        if(w[i]<=x) a=i,split(t[i].r,x,t[i].r,b);
        else b=i,split(t[i].l,x,a,t[i].l);
        update(i);
    }
}

那么我们在给平衡树持久化时,就可以方便的进行path copy。
记住一点,改谁就copy谁。(可以理解为跟update一起出现)

int merge(int a,int b) //此Treap是按大根堆去维护
{
    if(!a||!b)return a+b; //这个好帅
    if(t[a].rnd>t[b].rnd)
    {
        int p=++cnt;t[p]=t[a]; //path copy
        t[p].r=merge(t[p].r,b); //普通的treap是把所有p变成a或b的,这里可以体现出来持久化
        update(p);return p;
    }
    else
    {
        int p=++cnt;t[p]=t[b];
        t[p].l=merge(a,t[p].l);
        update(p);return p;
    }
}
void split(int now,int k,int &x,int &y) //这里的k是一个权值,表示把原平衡树分为小于等于k,和大于k的两部分
{
    if(!now)x=y=0;
    else
    {
        if(t[now].v<=k)
        {
            x=++cnt;t[x]=t[now];
            split(t[x].r,k,t[x].r,y);
            update(x);
        }
        else 
        {
            y=++cnt;t[y]=t[now];
            split(t[y].l,k,x,t[y].l);
            update(y);
        }
    }
}

附一个我自己写的丑丑的LG3835代码:

#include 
#include 
#include 
#include 
#include 
#define N 500010
using namespace std;
inline char gc() {
    static char now[1<<16], *S, *T;
    if(S == T) {T = (S = now) + fread(now, 1, 1<<16, stdin); if(S == T) return EOF;}
    return *S++;
}
inline int read() {
    int x = 0, f = 1; char c = gc();
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = gc();}
    while(c >= '0' && c <= '9') {x = x * 10 + c - 48; c = gc();}
    return x * f;
}
vector<int> pos;
struct treap {int l, r, sz, w, h;}t[N*50];
int root[N], len = 0, n;
inline int newnode(int x) {int ret; if(pos.size()) ret = pos.back(), pos.pop_back(); else ret = ++len; t[ret].l = 0; t[ret].r = 0; t[ret].sz = 1; t[ret].h = rand(); t[ret].w = x; return ret;}
inline void update(int p) {t[p].sz = t[t[p].l].sz + 1 + t[t[p].r].sz;}
int merge1(int i, int j) {
    if(!i || !j) return i + j;
    if(t[i].h < t[j].h) {
        int p = newnode(0); t[p] = t[i];
        t[p].r = merge1(t[p].r, j);
        update(p); return p;
    }else {
        int p = newnode(0); t[p] = t[j];
        t[p].l = merge1(i, t[p].l);
        update(p); return p;
    }
}
void split1(int p, int x, int &a, int &b) {//first: '<=' second: '>'
    if(!p) a = b = 0;
    else {
        if(t[p].w <= x) {
            a = newnode(0); t[a] = t[p];
            split1(t[a].r, x, t[a].r, b);
            update(a);
        }else {
            b = newnode(0); t[b] = t[p];
            split1(t[b].l, x, a, t[b].l);
            update(b);
        }
    }
}
void split2(int p, int k, int &a, int &b) {
    if(!p) a = b = 0;
    else {
        if(t[t[p].l].sz >= k) {
            b = newnode(0); t[b] = t[p];
            split2(t[b].l, k, a, t[b].l);
            update(b);
        }else {
            a = newnode(0); t[a] = t[p];
            split2(t[a].r, k - 1 - t[t[p].l].sz, t[a].r, b);
        }
    }
}
inline void insert1(int p, int x) {
    int a, b; split1(root[p], x, a, b);
    int nx = newnode(x);
    root[p] = merge1(a, merge1(nx, b));
}
inline void delete2(int p, int x) {
    int a, b, c, d;
    split1(root[p], x - 1, a, b);
    split2(b, 1, c, d);
    if(t[c].w == x) {
        pos.push_back(c);
        root[p] = merge1(a, d);
    }else root[p] = merge1(a, merge1(c, d));
}
inline int getrank(int p, int x) {
    int a, b;
    split1(root[p], x - 1, a, b);
    int ret = t[a].sz + 1;
    root[p] = merge1(a, b);
    return ret;
}
inline int getsa(int p, int k) {
    int a, b, c, d;
    split2(root[p], k - 1, a, b);
    split2(b, 1, c, d);
    int ret = t[c].w;
    root[p] = merge1(a, merge1(c, d));
    return ret;
}
inline int pre(int p, int x) {
    int a, b, c, d;
    split1(root[p], x - 1, a, b);
    split2(a, t[a].sz - 1, c, d);
    int ret = -0x7fffffff;
    if(d) ret = t[d].w;
    root[p] = merge1(c, merge1(d, b));
    return ret;
}
inline int nxt(int p, int x) {
    int a, b, c, d;
    split1(root[p], x, a, b);
    split2(b, 1, c, d);
    int ret = 0x7fffffff;
    if(c) ret = t[c].w;
    root[p] = merge1(a, merge1(c, d));
    return ret;
}
int main() {
    srand(20011118);
    n = read();
    for(int i = 1; i <= n; ++i) {
        int tim = read(), opt = read(), x = read();
        root[i] = root[tim];
        if(opt == 1) insert1(i, x);
        if(opt == 2) delete2(i, x);
        if(opt == 3) printf("%d\n", getrank(i, x));
        if(opt == 4) printf("%d\n", getsa(i, x));
        if(opt == 5) printf("%d\n", pre(i, x));
        if(opt == 6) printf("%d\n", nxt(i, x));
    }
    return 0;
}

你可能感兴趣的:(数据结构)