splay模板(功能最全)

回想起来splay大概是高中时的噩梦吧,手敲splay的确挑战太大了,不过到了acm时期,其实应该是不用惧怕splay的,把板子准备好就问题不大。

poj3580
虽然不是最简单的题,但是可以作为板子题
https://vjudge.net/problem/POJ-3580

题目大意:
要求实现一种数据结构,支持对一个数字序列的 6 种操作:
ADD x y val:第 x 个数到第 y 个数之间的数每个加 D;
REVERSE x y:第 x 个数到第 y 个数之间全部数翻转;
REVOLVE x y c:第 x 个数到第 y 个数之间的数,向后循环流动 c 次,即后面 c
个数变成这段子序列的最前面 c 个,前面的被挤到后面。
INSERT x P:在第 x 个数后面插入一个数 P。
DELETE x:删除第 x 个数。
MIN x y:求第 x 个数到第 y 个数之间的最小数字。

解释一下我的模板,首先有一
个很重要的操作:提取区间,要操作[a,b]区间,那么将 a-1 先旋转到根节点,
再将 b+1 旋转到根节点的右子节点,那么由中序遍历的性质可以知道,此时根
节点的右子节点的左子树是我们要操作的区间。因此在提取了操作的区间以
后:
1、对于区间加的操作, 只需要加上一个加标记,并且修改这个节点维护的区间
最值信息;
2、对于区间翻转操作,只需要加上一个翻转标记;
3、对于区间插入操作,如果插入到 pos,那么我们可以选择区间[pos+1,pos]
(对,大的在前)也就是把 pos 旋到根,把 pos+1 旋到根的右子结点,并且此
时根的右子结点的左子树是空的,然后插入到根的右子结点的左子结点上(这
个操作也可以把砍下来的一棵子树重新拼接上去);
4、对于删除结点操作,只需将该结点放入回收队列中,并且修改根的右子结点
的信息即可(砍树的时候只需要断开父子关系并且修改父亲结点信息);
5、对于求区间最值操作,由于每一个节点本身已经维护了一个子树信息,直接
提取即可;
6、对于区间移位操作,在右移之后,可以发现[a,b]被分为两个区间[a,bc] [b-c+1,b],将后者插入到前者之前即可( c 可能很大,需要先取模)。
所有操作除了翻转以外都需要先把延迟标记推下来,否则就会更新不及时而出
错,具体见代码。
另外,由于内存占用比较多,所以我加上了一个内存回收队列,把删掉的节点
放入到那里去,然后插入新节点的时候先看回收队列中有无可用节点,没有的
话再向内存池申请。
手调过的板子:

#include 
#include 
#include 
#include 
#define rep(i) for (int i=0; i
using namespace std;
typedef long long ll;
const int N=200005, inf=0x3f3f3f3f;

typedef struct splaynode* node;
struct splaynode {
    node pre, ch[2];
    int value, lazy, min;
    int size, rev;
    void init(int _value) {
        pre=ch[0]=ch[1]=NULL;
        min=value=_value;
        lazy=rev=0;
        size=1;
    }
}mem[N];
int memtop;

stack S;
node root;

inline int getsize(node &x) {
    return x?x->size:0;
}

void pushdown(node &x) {
    if (!x) return;
    if (x->lazy) {
        int w = x->lazy;
        x->value += w;
        if (x->ch[0]) {
            x->ch[0]->lazy += w;
            x->ch[0]->min += w;
        }
        if (x->ch[1]) {
            x->ch[1]->lazy += w;
            x->ch[1]->min += w;
        }
        x->lazy = 0;
    }
    if (x->rev) {
        node t = x->ch[0];
        x->ch[0] = x->ch[1];
        x->ch[1] = t;
        x->rev = 0;
        if (x->ch[0]) x->ch[0]->rev ^= 1;
        if (x->ch[1]) x->ch[1]->rev ^= 1;
    }
}

void update(node &x) {
    if (!x) return;
    x->size = 1;
    x->min = x->value;
    if (x->ch[0]) {
        x->min = min(x->min, x->ch[0]->min);
        x->size += x->ch[0]->size;
    }
    if (x->ch[1]) {
        x->min = min(x->min, x->ch[1]->min);
        x->size += x->ch[1]->size;
    }
}

void rotate(node &x, int d) {
    node y = x->pre;
    pushdown(y);
    pushdown(x);
    pushdown(x->ch[d]);
    y->ch[!d] = x->ch[d];
    if (x->ch[d] != NULL) x->ch[d]->pre = y;
    x->pre = y->pre;
    if (y->pre != NULL)
        if (y->pre->ch[0] == y) y->pre->ch[0] = x; else y->pre->ch[1] = x;
    x->ch[d] = y;
    y->pre = x;
    update(y);
    if (y == root) root = x;
}

void splay(node &src, node &dst) {
    pushdown(src);
    while (src!=dst) {
        if (src->pre==dst) {
            if (dst->ch[0]==src) rotate(src, 1); else rotate(src, 0);
            break;
        }
        else {
            node y=src->pre, z=y->pre;
            if (z->ch[0]==y) {
                if (y->ch[0]==src) {
                    rotate(y, 1);
                    rotate(src, 1);
                }else {
                    rotate(src, 0);
                    rotate(src, 1);
                }
            }
            else {
                if (y->ch[1]==src) {
                    rotate(y, 0);
                    rotate(src, 0);
                }else {
                    rotate(src, 1);
                    rotate(src, 0);
                }
            }
            if (z==dst) break;
        }
        update(src);
    }
    update(src);
}

void select(int k, node &f) {
    int tmp;
    node t = root;
    while (1) {
        pushdown(t);
        tmp = getsize(t->ch[0]);
        if (k == tmp + 1) break;
        if (k <= tmp) t = t->ch[0];
        else {
            k -= tmp + 1;
            t = t->ch[1];
        }
    }
    pushdown(t);
    splay(t, f);
}

inline void selectsegment(int l,int r) {
    select(l, root);
    select(r + 2, root->ch[1]);
}

void insert(int pos, int value) {  //在pos位置后面插入一个新值value
    selectsegment(pos + 1, pos);
    node t;
    node x = root->ch[1];
    pushdown(root);
    pushdown(x);
    if (!S.empty()) {
        t = S.top();
        S.pop();
    } else {
        t = &mem[memtop++];
    }
    t->init(value);
    t->ch[1] = x;
    x->pre = t;
    root->ch[1] = t;
    t->pre = root;
    splay(x, root);
}

void add(int a,int b, int value) {  //区间[a,b]中的数都加上value
    selectsegment(a, b);
    node x = root->ch[1]->ch[0];
    pushdown(x);
    update(x);
    x->min += value;
    x->lazy += value;
    splay(x, root);
}

void reverse(int a, int b) {   //区间[a,b]中的数翻转
    selectsegment(a, b);
    root->ch[1]->ch[0]->rev ^= 1;
    node x = root->ch[1]->ch[0];
    splay(x, root);
}

void revolve(int a, int b, int t) { //区间[a,b]中的数向后循环移t位
    node p1, p2;
    selectsegment(a, b);
    select(b + 1 - t, root->ch[1]->ch[0]);
    p1 = root->ch[1]->ch[0];
    pushdown(p1);
    p2 = p1->ch[1];
    p1->ch[1] = NULL;

    select(a + 1, root->ch[1]->ch[0]);
    p1 = root->ch[1]->ch[0];
    pushdown(p1);
    p1->ch[0] = p2;
    p2->pre = p1;

    splay(p2, root);
}

int getmin(int a, int b) {   //取[a,b]中最小的值
    selectsegment(a, b);
    node x = root->ch[1];
    pushdown(x);
    x = x->ch[0];
    pushdown(x);
    update(x);
    return x->min;
}

void erase(int pos) {               //抹去第pos个元素
    selectsegment(pos, pos);
    pushdown(root->ch[1]);
    S.push(root->ch[1]->ch[0]);        //回收内存
    root->ch[1]->ch[0] = NULL;
    node x = root->ch[1];
    splay(x, root);
}

void initsplaytree(int *a, int n) {
    memtop = 0;
    root = &mem[memtop++];
    root->init(inf);
    root->ch[1] = &mem[memtop++];
    root->ch[1]->init(inf);
    while (!S.empty()) S.pop();
    rep(i) insert(i, a[i]);
}



/*----------Splay Template Over----------*/
int v[N];
int main() {
    int n, m;
    scanf("%d", &n);
    rep(i) scanf("%d", &v[i]);
    initsplaytree(v, n);
    scanf("%d", &m);
    while (m--) {
        char s[50];
        scanf("%s", s);
        if (s[0]=='A') {
            int l, r, d;
            scanf("%d%d%d", &l, &r, &d);
            add(l, r, d);
        }
        if (s[0]=='R') {
            int l, r;
            scanf("%d%d",&l, &r);
            if (s[3]=='E') reverse(l, r);
            else {
                int k;
                scanf("%d", &k);
                int tn=r-l+1;
                k=(k%tn+tn)%tn;
                if (l==r || k==0) continue;
                revolve(l, r, k);
            }
        }
        if (s[0]=='I') {
            int x, d;
            scanf("%d%d", &x, &d);
            insert(x, d);
        }
        if (s[0]=='D') {
            int x;
            scanf("%d", &x);
            erase(x);
        }
        if (s[0]=='M'){
            int l, r;
            scanf("%d%d", &l, &r);
            printf("%d\n", getmin(l, r));
        }
    }
}

单点赋值: 有一种偷懒的做法(修改第 k 个学生):查询[k,k]的最小值,然后给[k,k]区间加上差值即可。
区间赋值等操作应该是和线段数中的是一样的,可以参考线段树中的各种lazy操作。

另外比较常见的是区间移动操作,下面把他封装成函数

void cutandmove(int a,int b,int c) //移动区间[l,r]到位置c后
{
    selectsegment(a,b);
    node CutRoot=root->ch[1]->ch[0];
    CutRoot->pre=NULL;
    root->ch[1]->size-=CutRoot->size;
    root->ch[1]->ch[0]=NULL;

    selectsegment(c+1,c);

    CutRoot->pre=root->ch[1];
    root->ch[1]->ch[0]=CutRoot;
    root->ch[1]->size+=CutRoot->size;
//切树操作的话,就先选择要切的区间,然后断开根的右子结
点和其左子结点的联系,把要接上的节点旋转到根的右子结点出并清空其左子
结点,再把切下来的子树接上去即可。
}


void cut(int a,int b)  //删除区间[l.r]
{
    selectsegment(a,b);
    node CutRoot=root->ch[1]->ch[0];
    CutRoot->pre=NULL;
    root->size-=CutRoot->size;
    root->ch[1]->size-=CutRoot->size;
    root->ch[1]->ch[0]=NULL;
}
vector ans;
void inorder(node x)
{
    if (!x) return;
    pushdown(x);
    inorder(x->ch[0]);
    if (x->value!=inf) ans.push_back(x->value);
    inorder(x->ch[1]);
}
//inorder(root);

借用poj3468写了个全模板,基本上线段树操作齐了吧。

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define rep(i) for (int i=0; i
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
typedef long long ll;
const int N=100005, inf=0x3f3f3f3f;

typedef struct splaynode* node;
struct splaynode {
    node pre, ch[2];
    ll value, lazy, max, sum;
    int size, rev;
    void init(int _value) {
        pre=ch[0]=ch[1]=NULL;
        max=value=sum=_value;
        lazy=rev=0;
        size=1;
    }
}mem[N];
int memtop;

stack S;
node root;

inline int getsize(node &x) {
    return x ? x->size : 0;
}

void pushdown(node &x) {
    if (!x) return;
    if (x->lazy) {
        ll w = x->lazy;
        x->value += w;
        if (x->ch[0]) {
            x->ch[0]->lazy += w;
            x->ch[0]->max += w;
            x->ch[0]->sum += w*getsize(x->ch[0]);
        }
        if (x->ch[1]) {
            x->ch[1]->lazy += w;
            x->ch[1]->max += w;
            x->ch[1]->sum += w*getsize(x->ch[1]);
        }
        x->lazy = 0;
    }
    if (x->rev) {
        node t = x->ch[0];
        x->ch[0] = x->ch[1];
        x->ch[1] = t;
        x->rev = 0;
        if (x->ch[0]) x->ch[0]->rev ^= 1;
        if (x->ch[1]) x->ch[1]->rev ^= 1;
    }
}

void update(node &x) {
    if (!x) return;
    x->size = 1;
    x->max = x->value;
    x->sum = x->value;
    if (x->ch[0]) {
        x->sum += x->ch[0]->sum;
        x->max = max(x->max, x->ch[0]->max);
        x->size += x->ch[0]->size;
    }
    if (x->ch[1]) {
        x->sum += x->ch[1]->sum;
        x->max = max(x->max, x->ch[1]->max);
        x->size += x->ch[1]->size;
    }
}

void rotate(node &x, int d) {
    node y = x->pre;
    pushdown(y);
    pushdown(x);
    pushdown(x->ch[d]);
    y->ch[!d] = x->ch[d];
    if (x->ch[d] != NULL) x->ch[d]->pre = y;
    x->pre = y->pre;
    if (y->pre != NULL)
        if (y->pre->ch[0] == y) y->pre->ch[0] = x; else y->pre->ch[1] = x;
    x->ch[d] = y;
    y->pre = x;
    update(y);
    if (y == root) root = x;
}

void splay(node &src, node &dst) {
    pushdown(src);
    while (src!=dst) {
        if (src->pre==dst) {
            if (dst->ch[0]==src) rotate(src, 1); else rotate(src, 0);
            break;
        }
        else {
            node y=src->pre, z=y->pre;
            if (z->ch[0]==y) {
                if (y->ch[0]==src) {
                    rotate(y, 1);
                    rotate(src, 1);
                }else {
                    rotate(src, 0);
                    rotate(src, 1);
                }
            }
            else {
                if (y->ch[1]==src) {
                    rotate(y, 0);
                    rotate(src, 0);
                }else {
                    rotate(src, 1);
                    rotate(src, 0);
                }
            }
            if (z==dst) break;
        }
        update(src);
    }
    update(src);
}

void select(int k, node &f) {
    int tmp;
    node t = root;
    while (1) {
        pushdown(t);
        tmp = getsize(t->ch[0]);
        if (k == tmp + 1) break;
        if (k <= tmp) t = t->ch[0];
        else {
            k -= tmp + 1;
            t = t->ch[1];
        }
    }
    pushdown(t);
    splay(t, f);
}

inline void selectsegment(int l,int r) {
    select(l, root);
    select(r + 2, root->ch[1]);
}

void insert(int pos, int value) {  //在pos位置后面插入一个新值value
    selectsegment(pos + 1, pos);
    node t;
    node x = root->ch[1];
    pushdown(root);
    pushdown(x);
    if (!S.empty()) {
        t = S.top();
        S.pop();
    } else {
        t = &mem[memtop++];
    }
    t->init(value);
    t->ch[1] = x;
    x->pre = t;
    root->ch[1] = t;
    t->pre = root;
    splay(x, root);
}

void add(int a,int b, int value) {  //区间[a,b]中的数都加上value
    selectsegment(a, b);
    node x = root->ch[1]->ch[0];
    pushdown(x);
    update(x);
    x->max += value;
    x->lazy += value;
    splay(x, root);
}

void reverse(int a, int b) {   //区间[a,b]中的数翻转
    selectsegment(a, b);
    root->ch[1]->ch[0]->rev ^= 1;
    node x = root->ch[1]->ch[0];
    splay(x, root);
}

void revolve(int a, int b, int t) { //区间[a,b]中的数向后循环移t位
    node p1, p2;
    selectsegment(a, b);
    select(b + 1 - t, root->ch[1]->ch[0]);
    p1 = root->ch[1]->ch[0];
    pushdown(p1);
    p2 = p1->ch[1];
    p1->ch[1] = NULL;

    select(a + 1, root->ch[1]->ch[0]);
    p1 = root->ch[1]->ch[0];
    pushdown(p1);
    p1->ch[0] = p2;
    p2->pre = p1;

    splay(p2, root);
}

ll getmax(int a, int b) {   //取[a,b]中最小的值
    selectsegment(a, b);
    node x = root->ch[1];
    pushdown(x);
    x = x->ch[0];
    pushdown(x);
    update(x);
    return x->max;
}

ll getsum(int a, int b) {
    selectsegment(a, b);
    node x = root->ch[1];
    pushdown(x);
    x = x->ch[0];
    pushdown(x);
    update(x);
    return x->sum;
}

void erase(int pos) {               //抹去第pos个元素
    selectsegment(pos, pos);
    pushdown(root->ch[1]);
    S.push(root->ch[1]->ch[0]);        //回收内存
    root->ch[1]->ch[0] = NULL;
    node x = root->ch[1];
    splay(x, root);
}


void cutandmove(int a,int b,int c)
{
    selectsegment(a,b);
    node CutRoot=root->ch[1]->ch[0];
    CutRoot->pre=NULL;
    root->ch[1]->size-=CutRoot->size;
    root->ch[1]->ch[0]=NULL;

    selectsegment(c+1,c);

    CutRoot->pre=root->ch[1];
    root->ch[1]->ch[0]=CutRoot;
    root->ch[1]->size+=CutRoot->size;
}

void cut(int a,int b)
{
    selectsegment(a,b);
    node CutRoot=root->ch[1]->ch[0];
    CutRoot->pre=NULL;
    root->size-=CutRoot->size;
    root->ch[1]->size-=CutRoot->size;
    root->ch[1]->ch[0]=NULL;
}

vector ans;
void inorder(node x)
{
    if (!x) return;
    pushdown(x);
    inorder(x->ch[0]);
    if (x->value!=inf) ans.push_back(x->value);
    inorder(x->ch[1]);
}

void initsplaytree(ll *a, int n) {
    memtop = 0;
    root = &mem[memtop++];
    root->init(inf);
    root->ch[1] = &mem[memtop++];
    root->ch[1]->init(inf);
    while (!S.empty()) S.pop();
    rep(i) insert(i, a[i]);
}


/*----------Splay Template Over----------*/
ll v[N];
int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    rep(i) scanf("%lld", &v[i]);
    initsplaytree(v, n);
    while (m--) {
        char s[50];
        scanf("%s", s);
        int l, r;
        scanf("%d%d", &l, &r);
        if (s[0]=='Q') printf("%lld\n", getsum(l, r));
        else {
            int d;
            scanf("%d", &d);
            add(l, r, d);
        }
    }
    return 0;
}

你可能感兴趣的:(acm)