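伸展树（Splay Tree）的基本操作……但题目数据量比较大，递归版 splay 的调用栈很深；加上扩栈的预编译指令 `#pragma comment(linker, "/STACK:16777216")`，并选择 C++（而非 G++）提交才勉强通过——该指令只对 MSVC 编译器生效。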
// Implicit-key splay tree over the sequence 1..n (recursive top-down splay).
// Each test case reads "n m"; the loop stops on EOF or when n <= 0 and m <= 0.
// Then m commands follow (only the first character of the word is inspected):
//   'C' a b c : cut elements a..b out and reinsert them after the c-th element
//               of the remaining sequence;
//   otherwise a b : reverse elements a..b (lazy flip tag).
// The final sequence is printed space-separated on one line.
// A head node with v == 0 sits at in-order position 1, so element i is at
// position i + 1 and the left part of every split is non-empty.
#include <iostream>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <cmath>
#include <time.h>
// MSVC-only: enlarge the stack, because splay() recurses along the search path.
#pragma comment(linker, "/STACK:16777216")
#define maxn 300005
#define maxm 400005
#define eps 1e-10
#define mod 1000000007
#define lowbit(x) (x&(-x))
#define ls o<<1
#define rs o<<1 | 1
#define lson o<<1, L, mid
#define rson o<<1 | 1, mid+1, R
typedef long long LL;
//typedef int LL;
using namespace std;

struct node {
    int s;        // number of nodes in this subtree
    int v;        // stored value; 0 marks the head node
    int flip;     // pending "reverse this subtree" tag
    node* ch[2];  // ch[0] = left child, ch[1] = right child

    // Locate the x-th (1-based, in-order) node of this subtree:
    // -1 = this node, 0 = in the left subtree, 1 = in the right subtree.
    inline int cmp(int x) {
        if (x == ch[0]->s + 1) return -1;
        if (x > ch[0]->s + 1) return 1;
        else return 0;
    }
    // Recompute the subtree size from the children.
    inline void maintain(void) { s = ch[0]->s + ch[1]->s + 1; }
    // Apply a pending flip: swap the children and pass the tag down to them.
    inline void pushdown(void) {
        if (!flip) return;
        swap(ch[0], ch[1]);
        ch[0]->flip ^= 1;
        ch[1]->flip ^= 1;
        flip = 0;
    }
} *null, *root, C[maxn], *top;

// Lift o->ch[d ^ 1] into o's place (d == 0: left rotation, d == 1: right
// rotation); o is updated to point at the new subtree root.
void rotate(node* &o, int d) {
    node* k = o->ch[d ^ 1];
    o->ch[d ^ 1] = k->ch[d];
    k->ch[d] = o;
    o->maintain();
    k->maintain();
    o = k;
}

// Bring the k-th node (1-based, in-order, 1 <= k <= o->s) of subtree o to its
// root, handling two levels per recursive call (zig-zig / zig-zag, then zig).
void splay(node* &o, int k) {
    o->pushdown();
    int d = o->cmp(k);
    if (d == -1) return;
    if (d == 1) k -= o->ch[0]->s + 1;
    node* p = o->ch[d];
    p->pushdown();
    int d2 = p->cmp(k);
    int k2 = (d2 == 0 ? k : k - p->ch[0]->s - 1);
    if (d2 != -1) {
        splay(p->ch[d2], k2);
        if (d == d2) rotate(o, d ^ 1);
        else rotate(o->ch[d], d);
    }
    rotate(o, d ^ 1);
}

// Concatenate two trees (all of left precedes all of right). left must be
// non-empty; right may be null. Returns the new root.
node* merge(node* left, node* right) {
    splay(left, left->s);  // the last node of left has no right child after this
    left->ch[1] = right;
    left->maintain();
    return left;
}

// Split o into its first k nodes (left) and the remaining nodes (right).
// Requires 1 <= k <= o->s; right is null when k == o->s.
void split(node* o, int k, node* &left, node* &right) {
    splay(o, k);
    right = o->ch[1];
    o->ch[1] = null;
    left = o;
    left->maintain();
}

char s[10];
int n, m, cnt;

// Reset the node pool: slot 0 is the shared null sentinel (size 0) and slot 1
// is the head node with v == 0.
void init(void) {
    cnt = 0;
    top = C;
    null = top++;
    null->ch[0] = null->ch[1] = NULL;
    null->v = null->s = null->flip = 0;
    root = top++;
    root->ch[0] = root->ch[1] = null;
    root->v = root->s = root->flip = 0;
    root->maintain();
}

// Append 1..n; each new node takes the current tree as its left child, so the
// result is a left-leaning path whose in-order is head, 1, 2, ..., n.
void build(void) {
    node *p;
    for (int i = 1; i <= n; i++) {
        p = top++;
        p->ch[0] = root;
        p->ch[1] = null;
        p->v = i;
        p->flip = 0;
        p->maintain();
        root = p;
    }
}

// Print the sequence in order, skipping the head node (v == 0). This walk is
// iterative: the tree can be a path of n + 1 nodes (e.g. straight after
// build() when m == 0), and a recursive walk could overflow the stack.
void print(node* o) {
    stack<node*> pending;
    while (o != null || !pending.empty()) {
        while (o != null) {
            o->pushdown();  // settle the child order before descending
            pending.push(o);
            o = o->ch[0];
        }
        o = pending.top();
        pending.pop();
        if (o->v) {
            cnt++;
            printf("%d%c", o->v, cnt == n ? '\n' : ' ');
        }
        o = o->ch[1];
    }
}

// Execute the m commands of one test case.
void work(void) {
    int a, b, c;
    node *o, *left, *right, *mid;
    while (m--) {
        if (scanf("%9s", s) != 1) return;  // bounded read; stop on truncated input
        if (s[0] == 'C') {
            if (scanf("%d%d%d", &a, &b, &c) != 3) return;
            split(root, a, left, o);          // left = head + elements 1..a-1
            split(o, b - a + 1, mid, right);  // mid = elements a..b
            root = merge(left, right);
            split(root, c + 1, left, right);  // left = head + first c remaining elements
            root = merge(merge(left, mid), right);
        } else {
            if (scanf("%d%d", &a, &b) != 2) return;
            split(root, a, left, o);
            split(o, b - a + 1, mid, right);
            mid->flip ^= 1;
            root = merge(merge(left, mid), right);
        }
    }
}

int main(void) {
    // Check scanf's result: on EOF n and m keep their previous values, and the
    // loop would otherwise never terminate.
    while (scanf("%d%d", &n, &m) == 2 && (n > 0 || m > 0)) {
        init();
        build();
        work();
        print(root);
    }
    return 0;
}
后来还是老老实实地把递归的 splay 改成了非递归实现：给每个节点加上父指针 fa，先自顶向下找到第 k 个节点，再自底向上旋转到根。
// Implicit-key splay tree over the sequence 1..n, non-recursive version: nodes
// keep a parent pointer, the target is found top-down and then rotated up.
// Each test case reads "n m"; the loop stops on EOF or when n <= 0 and m <= 0.
// Then m commands follow (only the first character of the word is inspected):
//   'C' a b c : cut elements a..b out and reinsert them after the c-th element
//               of the remaining sequence;
//   otherwise a b : reverse elements a..b (lazy flip tag).
// The final sequence is printed space-separated on one line.
// A head node with v == 0 sits at in-order position 1, so element i is at
// position i + 1 and the left part of every split is non-empty.
#include <iostream>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <cmath>
#include <time.h>
#define maxn 300005
#define maxm 400005
#define eps 1e-10
#define mod 1000000007
#define lowbit(x) (x&(-x))
#define ls o<<1
#define rs o<<1 | 1
#define lson o<<1, L, mid
#define rson o<<1 | 1, mid+1, R
typedef long long LL;
//typedef int LL;
using namespace std;

struct node {
    int s;              // number of nodes in this subtree
    int v;              // stored value; 0 marks the head node
    int flip;           // pending "reverse this subtree" tag
    node *ch[2], *fa;   // children and parent (null sentinel when absent)

    // Locate the x-th (1-based, in-order) node of this subtree:
    // -1 = this node, 0 = in the left subtree, 1 = in the right subtree.
    // Only meaningful after pushdown(), since a pending flip swaps ch[0]/ch[1].
    inline int cmp(int x) {
        if (x == ch[0]->s + 1) return -1;
        if (x > ch[0]->s + 1) return 1;
        else return 0;
    }
    // Recompute the subtree size from the children.
    inline void maintain(void) { s = ch[0]->s + ch[1]->s + 1; }
    // Apply a pending flip: swap the children and pass the tag down to them.
    inline void pushdown(void) {
        if (!flip) return;
        swap(ch[0], ch[1]);
        ch[0]->flip ^= 1;
        ch[1]->flip ^= 1;
        flip = 0;
    }
} *null, *root, C[maxn], *top;

// Rotate o above its parent p. d == true: o is p's left child (right
// rotation); d == false: o is p's right child (left rotation). Fixes the
// grandparent's child link and all parent pointers, then the sizes of p and o.
void rotate(node* &o, bool d) {
    node *p = o->fa;
    p->ch[d ^ 1] = o->ch[d];
    o->fa = p->fa;
    if (p->fa != null) {
        if (p->fa->ch[0] == p) p->fa->ch[0] = o;
        else p->fa->ch[1] = o;
    }
    if (p->ch[d ^ 1] != null) p->ch[d ^ 1]->fa = p;
    p->fa = o;
    o->ch[d] = p;
    p->maintain();
    o->maintain();
}

// Splay the k-th node (1-based, in-order, 1 <= k <= o->s) of the tree rooted
// at o up to the root; o is left pointing at it. o is expected to be a tree
// root (o->fa == null), as in all callers below.
void splay(node* &o, int k) {
    // Top-down search. pushdown() must run BEFORE cmp() on every node, the
    // starting one included: a subtree handed over by split() can still carry
    // a flip tag, and until it is applied ch[0]->s is the size of the wrong side.
    o->pushdown();
    int d = o->cmp(k);
    while (d != -1) {
        if (d == 1) k -= o->ch[0]->s + 1;
        o = o->ch[d];
        o->pushdown();
        d = o->cmp(k);
    }
    // Bottom-up splaying with proper zig / zig-zig / zig-zag steps. Every
    // ancestor was pushed down during the search, so rotations see clean nodes.
    while (o->fa != null) {
        node *p = o->fa, *g = p->fa;
        bool isLeft = (o == p->ch[0]);
        if (g == null) {
            rotate(o, isLeft);                               // zig
        } else {
            bool parentLeft = (p == g->ch[0]);
            if (isLeft == parentLeft) rotate(p, parentLeft);  // zig-zig: parent first
            else rotate(o, isLeft);                           // zig-zag: first rotation
            rotate(o, parentLeft);                            // o now sits on g's side
        }
    }
}

// Concatenate two trees (all of left precedes all of right). left must be
// non-empty; right may be null. Returns the new root.
node* merge(node* left, node* right) {
    splay(left, left->s);  // the last node of left has no right child after this
    left->ch[1] = right;
    if (right != null) right->fa = left;
    left->maintain();
    return left;
}

// Split o into its first k nodes (left) and the remaining nodes (right); both
// become roots. Requires 1 <= k <= o->s; right is null when k == o->s.
void split(node* o, int k, node* &left, node* &right) {
    splay(o, k);
    right = o->ch[1];
    o->ch[1] = null;
    left = o;
    right->fa = null;  // harmless when right is the sentinel (null->fa == null)
    left->maintain();
}

char s[10];
int n, m, cnt;

// Reset the node pool: slot 0 is the shared null sentinel (size 0, its own
// parent) and slot 1 is the head node with v == 0.
void init(void) {
    cnt = 0;
    top = C;
    null = top++;
    null->fa = null;
    null->ch[0] = null->ch[1] = NULL;
    null->v = null->s = null->flip = 0;
    root = top++;
    root->ch[0] = root->ch[1] = null;
    root->fa = null;  // the pool is reused across test cases; don't keep a stale parent
    root->v = root->s = root->flip = 0;
    root->maintain();
}

// Append 1..n; each new node takes the current tree as its left child, so the
// result is a left-leaning path whose in-order is head, 1, 2, ..., n.
void build(void) {
    node *p;
    for (int i = 1; i <= n; i++) {
        p = top++;
        root->fa = p;
        p->ch[0] = root;
        p->ch[1] = null;
        p->fa = null;
        p->v = i;
        p->flip = 0;
        p->maintain();
        root = p;
    }
}

// Print the sequence in order, skipping the head node (v == 0). This walk is
// iterative: the tree can be a path of n + 1 nodes (e.g. straight after
// build() when m == 0), and a recursive walk could overflow the stack.
void print(node* o) {
    stack<node*> pending;
    while (o != null || !pending.empty()) {
        while (o != null) {
            o->pushdown();  // settle the child order before descending
            pending.push(o);
            o = o->ch[0];
        }
        o = pending.top();
        pending.pop();
        if (o->v) {
            cnt++;
            printf("%d%c", o->v, cnt == n ? '\n' : ' ');
        }
        o = o->ch[1];
    }
}

// Execute the m commands of one test case.
void work(void) {
    int a, b, c;
    node *o, *left, *right, *mid;
    while (m--) {
        if (scanf("%9s", s) != 1) return;  // bounded read; stop on truncated input
        if (s[0] == 'C') {
            if (scanf("%d%d%d", &a, &b, &c) != 3) return;
            split(root, a, left, o);          // left = head + elements 1..a-1
            split(o, b - a + 1, mid, right);  // mid = elements a..b
            root = merge(left, right);
            split(root, c + 1, left, right);  // left = head + first c remaining elements
            root = merge(merge(left, mid), right);
        } else {
            if (scanf("%d%d", &a, &b) != 2) return;
            split(root, a, left, o);
            split(o, b - a + 1, mid, right);
            mid->flip ^= 1;
            root = merge(merge(left, mid), right);
        }
    }
}

int main(void) {
    // Check scanf's result: on EOF n and m keep their previous values, and the
    // loop would otherwise never terminate.
    while (scanf("%d%d", &n, &m) == 2 && (n > 0 || m > 0)) {
        init();
        build();
        work();
        print(root);
    }
    return 0;
}