【splay tree】 HDOJ 3487 Play with Chain

伸展树的基本操作……但是题目数据量比较大，递归版的 splay 在树退化成长链时递归很深，容易爆栈；加了扩栈的预编译指令（#pragma comment(linker, "/STACK:...")，只对 MSVC 生效）并用 C++ 提交才勉强通过……

#include <iostream>  
#include <queue>  
#include <stack>  
#include <map>  
#include <set>  
#include <bitset>  
#include <cstdio>  
#include <algorithm>  
#include <cstring>  
#include <climits>  
#include <cstdlib>
#include <cmath>
#include <time.h>
// Linker directive requesting a 16 MB stack so the recursive splay()/print()
// survive deep (chain-shaped) trees. It is honoured by MSVC; most other
// toolchains ignore it.
#pragma comment(linker, "/STACK:16777216")
// Capacity of the node pool C[]: each test case uses n + 2 nodes
// (the null sentinel, the leftmost placeholder, and n elements).
#define maxn 300005
// The remaining macros and the LL typedef are contest-template leftovers;
// nothing in this program uses them.
#define maxm 400005
#define eps 1e-10
#define mod 1000000007  
#define lowbit(x) (x&(-x))  
#define ls o<<1
#define rs o<<1 | 1
#define lson o<<1, L, mid  
#define rson o<<1 | 1, mid+1, R  
typedef long long LL;
//typedef int LL;
using namespace std;

// Splay-tree node keyed by in-order rank, with a lazy "reverse" tag.
struct node
{
	int s, v, flip;   // s: subtree size; v: stored value (0 for sentinels); flip: pending reversal of this subtree
	node* ch[2];      // ch[0]: left child, ch[1]: right child (the shared `null` sentinel when empty)

	// Locate rank x (1-based, relative to this subtree):
	// -1 = this node, 0 = inside the left subtree, 1 = inside the right subtree.
	inline int cmp(int x)
	{
		if(x == ch[0]->s+1) return -1;
		if(x > ch[0]->s+1) return 1;
		else return 0;
	}

	// Recompute the subtree size from the children.
	inline void maintain(void)
	{
		s = ch[0]->s + ch[1]->s + 1;
	}

	// Apply a pending reversal: swap the children and pass the tag down.
	// The shared empty sentinel is the only node with s == 0 (maintain()
	// gives every real node s >= 1), so it is skipped. Without this guard,
	// pushing down a leaf toggled the sentinel's own flip field.
	inline void pushdown(void)
	{
		if(!flip) return;
		swap(ch[0], ch[1]);
		if(ch[0]->s) ch[0]->flip ^= 1;
		if(ch[1]->s) ch[1]->flip ^= 1;
		flip = 0;
	}
}*null, *root, C[maxn], *top;

// Rotate the subtree rooted at `o` in direction d (0 = left, 1 = right).
// The child on side d^1 becomes the new subtree root, and `o` is redirected
// to it. Sizes are refreshed bottom-up (old root first).
void rotate(node* &o, int d)
{
	node *old_root = o;
	node *new_root = old_root->ch[d ^ 1];
	old_root->ch[d ^ 1] = new_root->ch[d];
	new_root->ch[d] = old_root;
	old_root->maintain();
	new_root->maintain();
	o = new_root;
}
// Recursive rank-based splay: brings the node whose in-order rank inside the
// subtree `o` equals k (1-based) to the root of that subtree, updating `o`
// in place. Lazy flips are pushed down before each cmp() so child sizes are
// read in their true order. Recursion depth grows with tree height, so deep
// (chain-like) trees can need a large call stack.
void splay(node* &o, int k)
{
	o->pushdown();
	int d = o->cmp(k);
	if(d == -1) return;                      // o already has rank k
	if(d == 1) k -= o->ch[0]->s + 1;         // convert k to a rank inside the right subtree
	node* p = o->ch[d];
	p->pushdown();
	int d2 = p->cmp(k);
	// Rank of the target inside p->ch[d2] (only meaningful when d2 != -1).
	int k2 = (d2 == 0 ? k : k - p->ch[0]->s - 1);
	if(d2 != -1) {
		// Target is a grandchild of o: first splay it to the root of p->ch[d2].
		splay(p->ch[d2], k2);
		if(d == d2) rotate(o, d^1);          // zig-zig: lift p above o first
		else rotate(o->ch[d], d);            // zig-zag: lift the target above p
	}
	rotate(o, d^1);                          // final rotation lifts o->ch[d] to the root
}
// Concatenate two sequences: every element of `left` precedes every element
// of `right`. `left` must be non-empty. Returns the new root.
node* merge(node* left, node* right)
{
	node *joined = left;
	splay(joined, joined->s);   // rank s == last element; its right child is now empty
	joined->ch[1] = right;
	joined->maintain();
	return joined;
}
// Split the sequence rooted at `o` after its k-th element:
// `left` receives ranks 1..k, and `right` the rest (may be `null`).
void split(node* o, int k, node* &left, node* &right)
{
	splay(o, k);                 // rank-k node becomes the root
	node *tail = o->ch[1];
	o->ch[1] = null;
	o->maintain();
	right = tail;
	left = o;
}
// Operation keyword read by work(); only s[0] is inspected
// ('C' selects the move-range branch, anything else the reverse branch).
char s[10];
// n: number of elements, m: number of operations, cnt: elements printed so far.
int n, m, cnt;

// Reset the node pool and the printed-element counter for a new test case.
void init(void)
{
	top = C;
	cnt = 0;

	// Shared empty sentinel (size 0) standing in for every missing child.
	null = top++;
	null->flip = 0;
	null->s = 0;
	null->v = 0;
	null->ch[1] = NULL;
	null->ch[0] = NULL;

	// Leftmost placeholder node: v == 0, so print() skips it.
	root = top++;
	root->flip = 0;
	root->s = 0;
	root->v = 0;
	root->ch[1] = null;
	root->ch[0] = null;
	root->maintain();
}
void build(void)
{
	node *p;
	for(int i = 1; i <= n; i++) {
		p = top++;
		p->ch[0] = root;
		p->ch[1] = null;
		p->v = i;
		p->flip = 0;
		p->maintain();
		root = p;
	}
}
// Recursive in-order print of the subtree `o`. Lazy flips are resolved first
// so children are visited in their true order. Placeholder nodes (v == 0) are
// skipped; a newline follows the n-th printed element.
void print(node* o)
{
	o->pushdown();
	node *lhs = o->ch[0];
	node *rhs = o->ch[1];
	if(lhs != null) print(lhs);
	if(o->v != 0) {
		++cnt;
		printf("%d%c", o->v, cnt == n ? '\n' : ' ');
	}
	if(rhs != null) print(rhs);
}
void work(void)
{
	int a, b, c;
	node *o, *left, *right, *mid;
	while(m--) {
		scanf("%s", s);
		if(s[0] == 'C') {
			scanf("%d%d%d", &a, &b, &c);
			split(root, a, left, o);
			split(o, b-a+1, mid, right);
			root = merge(left, right);
			split(root, c+1, left, right);
			root = merge(merge(left, mid), right);
		}
		else {
			scanf("%d%d", &a, &b);
			split(root, a, left, o);
			split(o, b-a+1, mid, right);
			mid->flip^=1;
			root = merge(merge(left, mid), right);
		}
	}
}
// Process test cases until one has both n and m non-positive (presumably the
// judge's terminator line). Checking scanf's result also stops cleanly at EOF
// instead of re-running the last case forever.
int main(void)
{
	while(scanf("%d%d", &n, &m) == 2 && (n > 0 || m > 0)) {
		init();
		build();
		// Presumably meant to shorten the degenerate chain left by build().
		// After build() the tree holds root->s == n + 1 nodes, so rank 6 only
		// exists for n >= 5; the unconditional call crashed on smaller inputs.
		if(root->s >= 6) splay(root, 6);
		work();
		print(root);
	}
	return 0;
}

后来还是老老实实地把递归的 splay 改成了非递归实现……


#include <iostream>  
#include <queue>  
#include <stack>  
#include <map>  
#include <set>  
#include <bitset>  
#include <cstdio>  
#include <algorithm>  
#include <cstring>  
#include <climits>  
#include <cstdlib>
#include <cmath>
#include <time.h>
// Capacity of the node pool C[]: each test case uses n + 2 nodes
// (the null sentinel, the leftmost placeholder, and n elements).
#define maxn 300005
// The remaining macros and the LL typedef are contest-template leftovers;
// nothing in this program uses them.
#define maxm 400005
#define eps 1e-10
#define mod 1000000007  
#define lowbit(x) (x&(-x))  
#define ls o<<1
#define rs o<<1 | 1
#define lson o<<1, L, mid  
#define rson o<<1 | 1, mid+1, R  
typedef long long LL;
//typedef int LL;
using namespace std;

// Splay-tree node keyed by in-order rank, with a lazy "reverse" tag and a
// parent pointer for bottom-up splaying.
struct node
{
	int s, v, flip;     // s: subtree size; v: stored value (0 for sentinels); flip: pending reversal of this subtree
	node *ch[2], *fa;   // ch[0]/ch[1]: left/right child; fa: parent (`null` sentinel when absent)

	// Locate rank x (1-based, relative to this subtree):
	// -1 = this node, 0 = inside the left subtree, 1 = inside the right subtree.
	inline int cmp(int x)
	{
		if(x == ch[0]->s+1) return -1;
		if(x > ch[0]->s+1) return 1;
		else return 0;
	}

	// Recompute the subtree size from the children.
	inline void maintain(void)
	{
		s = ch[0]->s + ch[1]->s + 1;
	}

	// Apply a pending reversal: swap the children and pass the tag down.
	// The shared empty sentinel is the only node with s == 0 (maintain()
	// gives every real node s >= 1), so it is skipped. Without this guard,
	// pushing down a leaf toggled the sentinel's own flip field.
	inline void pushdown(void)
	{
		if(!flip) return;
		swap(ch[0], ch[1]);
		if(ch[0]->s) ch[0]->flip ^= 1;
		if(ch[1]->s) ch[1]->flip ^= 1;
		flip = 0;
	}
}*null, *root, C[maxn], *top;
/*
void rotate(node* &o, int d)
{
	node* k = o->ch[d^1]; o->ch[d^1] = k->ch[d], k->ch[d] = o;
	o->maintain(), k->maintain(), o = k;
}
void splay(node* &o, int k)
{
	o->pushdown();
	int d = o->cmp(k);
	if(d == -1) return;
	if(d == 1) k -= o->ch[0]->s + 1;
	node* p = o->ch[d];
	p->pushdown();
	int d2 = p->cmp(k);
	int k2 = (d2 == 0 ? k : k - p->ch[0]->s - 1);
	if(d2 != -1) {
		splay(p->ch[d2], k2);
		if(d == d2) rotate(o, d^1);
		else rotate(o->ch[d], d);
	}
	rotate(o, d^1);
}
*/
// Lift `o` one level above its parent. d == true means `o` is a left child
// (right rotation); d == false means `o` is a right child (left rotation).
// Parent links and subtree sizes are kept consistent.
void rotate(node* &o, bool d)
{
	node *par = o->fa;
	node *grand = par->fa;
	node *moved = o->ch[d];            // inner subtree that changes parent

	par->ch[d ^ 1] = moved;
	o->fa = grand;
	if(grand != null)
		grand->ch[grand->ch[0] == par ? 0 : 1] = o;
	if(moved != null) moved->fa = par;

	par->fa = o;
	o->ch[d] = par;
	par->maintain();
	o->maintain();
}
// Bring the node of rank k (1-based) in the tree rooted at `o` to the root,
// updating `o` in place. `o` must be the current root (o->fa == null) and
// 1 <= k <= o->s.
void splay(node* &o, int k)
{
	// Phase 1: walk down to the rank-k node. Each node's lazy flip is pushed
	// down BEFORE cmp() reads its child sizes; otherwise a pending flip on the
	// current node (e.g. a subtree root produced by split()) sends the walk
	// into the wrong subtree.
	o->pushdown();
	int d = o->cmp(k);
	while(d != -1) {
		if(d == 1) k -= o->ch[0]->s + 1;
		o = o->ch[d];
		o->pushdown();
		d = o->cmp(k);
	}

	// Phase 2: rotate `o` up with proper splay steps. rotate(x, true) lifts a
	// left child and rotate(x, false) a right child. A genuine zig-zag is
	// needed to keep the amortized O(log n) bound.
	while(o->fa != null) {
		node *p = o->fa, *g = p->fa;
		bool isLeft = (p->ch[0] == o);
		if(g == null) {               // zig: parent is the root
			rotate(o, isLeft);
			break;
		}
		bool pIsLeft = (g->ch[0] == p);
		if(isLeft == pIsLeft) {       // zig-zig: rotate the parent first
			rotate(p, pIsLeft);
			rotate(o, isLeft);
		}
		else {                        // zig-zag: rotate o twice
			rotate(o, isLeft);
			rotate(o, pIsLeft);
		}
	}
}
// Concatenate two sequences: every element of `left` precedes every element
// of `right`. `left` must be non-empty. Returns the new root.
node* merge(node* left, node* right)
{
	node *joined = left;
	splay(joined, joined->s);      // last element of `left` becomes the root
	joined->ch[1] = right;         // its right child is empty after the splay
	if(right != null) right->fa = joined;
	joined->maintain();
	return joined;
}
// Split the sequence rooted at `o` after its k-th element:
// `left` receives ranks 1..k, and `right` the rest (may be `null`).
// Both results are detached roots (fa == null).
void split(node* o, int k, node* &left, node* &right)
{
	splay(o, k);                   // rank-k node becomes the root
	node *tail = o->ch[1];
	o->ch[1] = null;
	tail->fa = null;
	o->maintain();
	right = tail;
	left = o;
}
// Operation keyword read by work(); only s[0] is inspected
// ('C' selects the move-range branch, anything else the reverse branch).
char s[10];
// n: number of elements, m: number of operations, cnt: elements printed so far.
int n, m, cnt;

// Reset per-test-case state. The node pool C[] is reused across test cases,
// so every field later code reads must be (re)initialised here.
void init(void)
{
	cnt = 0;
	top = C;
	// Shared empty sentinel: size 0, used for every missing child/parent.
	null = top++;
	null->fa = null;
	null->ch[0] = null->ch[1] = NULL;
	null->v = null->s = null->flip = 0;
	// Leftmost placeholder node (v == 0, skipped by print()).
	root = top++;
	root->ch[0] = root->ch[1] = null;
	// splay() climbs fa links until it reaches `null`, so the fresh root must
	// not inherit a stale parent from the previous test case. build()
	// overwrites this when n >= 1.
	root->fa = null;
	root->v = root->s = root->flip = 0;
	root->maintain();
}
void build(void)
{
	node *p;
	for(int i = 1; i <= n; i++) {
		p = top++;
		root->fa = p;
		p->ch[0] = root;
		p->ch[1] = null;
		p->fa = null;
		p->v = i;
		p->flip = 0;
		p->maintain();
		root = p;
	}
}
// In-order print of the subtree `o`. Placeholder nodes (v == 0) are skipped;
// a newline follows the n-th printed element.
// Uses an explicit stack instead of recursion: right after build() the tree
// is a left chain of n + 1 nodes (and stays that way when m == 0), so a
// recursive traversal would need ~n stack frames.
void print(node* o)
{
	stack<node*> pending;
	node *cur = o;
	while(cur != null || !pending.empty()) {
		// Descend left, resolving each lazy flip before reading ch[0].
		while(cur != null) {
			cur->pushdown();
			pending.push(cur);
			cur = cur->ch[0];
		}
		cur = pending.top();
		pending.pop();
		if(cur->v) cnt++, printf("%d%c", cur->v, cnt == n ? '\n' : ' ');
		cur = cur->ch[1];          // already in true order after pushdown()
	}
}
void work(void)
{
	int a, b, c;
	node *o, *left, *right, *mid;
	while(m--) {
		scanf("%s", s);
		if(s[0] == 'C') {
			scanf("%d%d%d", &a, &b, &c);
			split(root, a, left, o);
			split(o, b-a+1, mid, right);
			root = merge(left, right);
			split(root, c+1, left, right);
			root = merge(merge(left, mid), right);
		}
		else {
			scanf("%d%d", &a, &b);
			split(root, a, left, o);
			split(o, b-a+1, mid, right);
			mid->flip^=1;
			root = merge(merge(left, mid), right);
		}
	}
}
// Process test cases until one has both n and m non-positive (presumably the
// judge's terminator line). Checking scanf's result also stops cleanly at EOF
// instead of re-running the last case forever with stale n/m.
int main(void)
{
	while(scanf("%d%d", &n, &m) == 2 && (n > 0 || m > 0)) {
		init();
		build();
		work();
		print(root);
	}
	return 0;
}


你可能感兴趣的:(HDU)