注:此代码非本人所写,实现方式为自底向上,在做ACM题目时比教科书上的自顶向下在数据处理上要好操作。
由于个人现在比较钟爱指针,所以从网上找了个指针实现的源码,只是添加了点个人的理解注释和一些其他操作。
http://dongxicheng.org/structure/splay-tree/此博客对Splay Treee分析比较透彻,需要注意的是,Splay Tree的旋转操作与其他平衡树如Red Black Tree的操作在代码上有些不同,具体可以画图得出。
理解了Splay Tree的Splay函数和其对区间操作的方法,基本上就掌握Splay Tree了。
#include <iostream> #include <cstdio> #include <cstdlib> using namespace std; int const maxn=300000; #define bint __int64 struct node { int value,size; int add; bint sum; node *ch[2], *pre; }tree[maxn],*Null, *root;//设置一个Null指针是一个亮点 int num[maxn]; int n,m; int top=0; node *New_Node(int x) { node *p=&tree[top++]; p->ch[0]=p->ch[1]=p->pre=Null; p->add=0; p->sum=x; p->size=1; p->value=x; return p; } void Push_Down(node *x) { if (x == Null) return; x->value+=x->add; x->ch[0]->add+=x->add; x->ch[1]->add+=x->add; if(x->ch[0]!=Null) x->ch[0]->sum+=(bint)(x->add)*(x->ch[0]->size); if(x->ch[1]!=Null) x->ch[1]->sum+=(bint)(x->add)*(x->ch[1]->size); x->add=0; } void Update(node *x) { if (x == Null) return; x->size = x->ch[0]->size + x->ch[1]->size + 1; x->sum = (bint)x->value+ x->ch[0]->sum + x->ch[1]->sum; } //这里的旋转和redblack不同,是将节点x向上旋 void Rotate(node *x, int c)//c=0: left rotate c=1: right rotate { node *y = x->pre; Push_Down(y); Push_Down(x); y->ch[! c] = x->ch[c]; x->pre = y->pre; if (x->ch[c] != Null) x->ch[c]->pre = y; if (y->pre != Null) y->pre->ch[y->pre->ch[1] == y] = x;//判断原来y->pre的左或右孩子是y y->pre = x; x->ch[c] = y; if (y == root) root = x; Update(y); } //将节点x放到节点f的下面 void Splay(node *x, node *f) { Push_Down(x); while(x->pre != f) { if (x->pre->pre == f) Rotate(x, x->pre->ch[0] == x); else { node *y = x->pre, *z = y->pre; if (z->ch[0] == y)//left if (y->ch[0] == x) //LL Rotate(y, 1), Rotate(x, 1); else Rotate(x, 0), Rotate(x, 1);//LR else//right if (y->ch[1] == x) //RR Rotate(y, 0), Rotate(x, 0); else Rotate(x, 1), Rotate(x, 0);//RL } } Update(x); } //找到处在中序遍历第K个节点,并将其旋转到节点f的下面 void Select(int k, node *f) { node *now=root; while(true) { Push_Down(now); int tmp = now->ch[0]->size; if (tmp + 1 == k) break; if (k <= tmp) now = now->ch[0]; else now = now->ch[1], k -= tmp + 1; } //printf("Select: %d\n",now->value); Splay(now, f); } node *Make_Tree(int l, int r, node *fa) { if (l > r) return Null; int mid = (l + r) >> 1; node *p = New_Node(num[mid]); p->ch[0] = Make_Tree(l, mid-1, p); p->ch[1] = Make_Tree(mid+1, r, p); p->pre = fa; Update(p); return p; } void remove(int left,int right){//删除区间[left,right] Select(left,Null); Select(right+2,root); root->ch[1]->ch[0] = Null; Splay(root->ch[1],Null); } void ADD(int left,int right,int cnt){ Select(left,Null); Select(right+2,root); root->ch[1]->ch[0]->add += cnt; root->ch[1]->ch[0]->sum+=(bint)cnt*(root->ch[1]->ch[0]->size); Splay(root->ch[1]->ch[0],Null); } node *makeTree(int l, int r, int value, node *fa) { if (l > r) return Null; int mid = (l + r) >> 1; node *p = New_Node(value); p->ch[0] = makeTree(l, mid-1,value, p); p->ch[1] = makeTree(mid+1, r,value, p); p->pre = fa; Update(p); return p; } void insert(int pos,int cnt,int value){//在pos位置后面连续插入cnt个值为value的数 Select(pos+1,Null); Select(pos+2,root); root->ch[1]->ch[0] = makeTree(1,cnt,value,root->ch[1]); Splay(root->ch[1]->ch[0],Null); } void print(node *root){ node *p = root; if(p!=Null){ if(p->ch[0]!=Null)print(p->ch[0]); printf("%d,parent=%d,size=%d\n",p->value,p->pre->value,p->pre->size); if(p->ch[1]!=Null) print(p->ch[1]); } } int main() { top=0; scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) { scanf("%d",&num[i]); } Null=New_Node(0); Null->size=0; //head root=New_Node(-1); //tail root->ch[1]=New_Node(-1); root->ch[1]->pre=root; Update(root); //root->ch[1] is the really root root->ch[1]->ch[0]=Make_Tree(1,n,root->ch[1]); Update(root->ch[1]); Update(root); //print(root); char s[2]; for(int i=0;i<m;i++) { scanf("%s",s); if(s[0]=='C') { int a,b,c; scanf("%d%d%d",&a,&b,&c); ADD(a,b,c); // Select(a,Null); // print(root); // Select(b+2,root); // cout<<"=================================="<<endl; // print(root); //root->ch[1]->ch[0]->add+=c; //root->ch[1]->ch[0]->sum+=(bint)c*(root->ch[1]->ch[0]->size); } else { int a,b; scanf("%d%d",&a,&b); //先把第a-1的数节点放到Null下,再将第b+1的数节点放到root下 //(即root的右子树),那么剩下(root左子树就是区间[a,b]) Select(a,Null);//由于root的size是1,所以这里是a-1+1=a,a-1就变成root了 //print(root); Select(b+2,root);//同样这里是b+1+1=b+2,a-1的右子树 // cout<<"=================================="<<endl; // print(root); printf("%I64d\n",root->ch[1]->ch[0]->sum); } } return 0; }