【splay tree】 HDOJ Queue-jumpers

我只能说，这种类型的题目在比赛时写一定要非常小心……调了很久才调出来。

注意到题目中 n 的范围是 10^8，而 q 的范围只有 10^5，显然需要离散化。先离线读入所有操作，把 Top 和 Query 涉及的每个编号单独作为一个点，其余连续的编号区间各自缩成一个点（记录区间长度作为权值），然后用 splay 维护这个序列即可。

#include <iostream>  
#include <queue>  
#include <stack>  
#include <map>  
#include <set>  
#include <bitset>  
#include <cstdio>  
#include <algorithm>  
#include <cstring>  
#include <climits>  
#include <cstdlib>
#include <cmath>
#include <time.h>
#define maxn 400005
#define maxm 400005
#define eps 1e-10
#define mod 1000000007 
#define INF 99999999 
#define lowbit(x) (x&(-x))  
#define ls o<<1
#define rs o<<1 | 1
#define lson o<<1, L, mid  
#define rson o<<1 | 1, mid+1, R  
typedef long long LL;
//typedef int LL;
using namespace std;

// One operation as stored by read(): the first letter of the command word
// and its integer argument.
struct opr
{
	char p;	// first character of the command token
	int v;	// numeric argument that follows the command
};

// Operation list; read() fills entries 1..m.
opr op[maxn];
// Inclusive range [a, b] of labels produced by read(); build() creates one
// tree node per entry with weight b - a + 1.
struct point
{
	int a;	// first label of the range
	int b;	// last label of the range
};

// Compressed ranges; read() fills entries 1..ycnt in increasing label order.
point po[maxn];
// Labels named by 'T'/'Q' operations; read() sorts and de-duplicates them (1-based).
int xx[maxn];
// x[i] is the index in po[] (and loc[]) of the single-label range holding xx[i].
int x[maxn];
// Number of valid entries in xx[].
int xcnt;
// Number of valid entries in po[].
int ycnt;
// First header value of a test case: the largest label (upper end of the last range).
int n;
// Second header value of a test case: the number of operations.
int m;
// Scratch buffer for the command word of the operation being read.
char s[100];
// Tree node. Missing children and the parent of a tree's top node are
// represented by the shared sentinel `null` rather than a null pointer.
struct node
{
	int s;		// number of nodes in this subtree (this node included)
	int v;		// weight of this node
	int sum;	// total weight of this subtree
	int a;		// first label covered by this node
	int b;		// last label covered by this node
	node *ch[2];	// ch[0] = left child, ch[1] = right child
	node *fa;	// parent

	// Compares the 1-based position x with this node's position inside its
	// own subtree: -1 when x is this node, 1 when x lies in the right
	// subtree, 0 when it lies in the left subtree.
	int cmp(int x)
	{
		const int here = ch[0]->s + 1;
		if(x == here) return -1;
		return x > here ? 1 : 0;
	}

	// Recomputes s and sum from the two children and this node's own weight.
	void maintain(void)
	{
		const node *lc = ch[0];
		const node *rc = ch[1];
		s = lc->s + rc->s + 1;
		sum = lc->sum + rc->sum + v;
	}
};

node *null;		// sentinel; init() points it at C[0]
node *root;		// top node of the main tree
node C[maxn];		// node pool
node *top;		// last pool slot handed out (see newnode())
node *loc[maxn];	// loc[i] = node built for po[i]
// Rotates o above its parent. After the call the former parent is o->ch[d]
// and o's former ch[d] subtree hangs in the parent's ch[d^1] slot.
// splay() passes d == true when o is a left child and false when it is a
// right child. o itself is never reassigned here.
void rotate(node* &o, bool d)
{
	node *const parent = o->fa;
	node *const grand = parent->fa;
	node *const inner = o->ch[d];	// subtree that changes owner
	const int freed = d ? 0 : 1;	// same value as d^1

	parent->ch[freed] = inner;
	o->fa = grand;
	// Hook o into whichever slot of the grandparent used to hold the parent.
	if(grand != null) {
		const int slot = (grand->ch[0] == parent) ? 0 : 1;
		grand->ch[slot] = o;
	}
	// The sentinel is shared, so its parent link is left alone.
	if(inner != null)
		inner->fa = parent;
	parent->fa = o;
	o->ch[d] = parent;

	// The parent is now below o, so refresh it first.
	parent->maintain();
	o->maintain();
}
// Rotates o upward until its parent is the sentinel, i.e. o is the top node
// of its tree. When o and its parent are children on the same side, the
// parent is rotated first and then o; otherwise only o is rotated in this
// iteration.
void splay(node* &o)
{
	for(node *parent = o->fa; parent != null; parent = o->fa) {
		const bool onLeft = (parent->ch[0] == o);
		const int side = onLeft ? 0 : 1;
		// If the parent is the top node, parent->fa is the sentinel, whose
		// children are the sentinel itself, so this test is false.
		if(parent->fa->ch[side] == parent)
			rotate(parent, onLeft);
		rotate(o, onLeft);
	}
}
// Walks down from o to the node at 1-based position k of o's subtree, then
// splays that node. o is a reference and ends up pointing at the node found.
// k is expected to be within 1..o->s; nothing here checks that.
void kth(node* &o, int k)
{
	for(;;) {
		const int dir = o->cmp(k);
		if(dir == -1)
			break;
		if(dir != 0) {
			// Skip the left subtree and the current node.
			k -= o->ch[0]->s + 1;
			o = o->ch[1];
		}
		else {
			o = o->ch[0];
		}
	}
	splay(o);
}
// Attaches `right` as the right child of `left` and returns `left`.
//
// `left` must be a real node; whatever was in left->ch[1] is overwritten, so
// callers pass a node whose right slot is already empty (see split()).
// `right` may be the sentinel `null`, meaning "nothing to append". In that
// case the sentinel's parent link is deliberately left untouched: the
// sentinel is shared by every empty slot and must keep pointing at itself.
node* merge(node *left, node *right)
{
	left->ch[1] = right;
	if(right != null)
		right->fa = left;
	left->maintain();
	return left;
}
// Cuts the right subtree off o. On return `right` is the detached subtree
// (possibly the sentinel) with its parent link reset, and `left` is o with an
// empty right slot and refreshed size/sum.
void split(node *o, node* &left, node* &right)
{
	node *const tail = o->ch[1];
	right = tail;
	tail->fa = null;
	o->ch[1] = null;
	o->maintain();
	left = o;
}
// Descends from o by cumulative weight to the node whose weight interval
// contains k, then prints the label at that offset inside the node's range
// (o->a plus the offset, 0-based). The tree is not restructured.
// k is expected to be within 1..o->sum; nothing here checks that.
void find(node* o, int k)
{
	for(;;) {
		const int before = o->ch[0]->sum;	// weight strictly left of o
		const int upto = before + o->v;		// weight up to and including o
		if(before < k && k <= upto) {
			const int offset = k - before;	// 1-based offset inside this node
			printf("%d\n", offset + o->a - 1);
			return;
		}
		if(k > upto) {
			k -= upto;
			o = o->ch[1];
		}
		else {
			o = o->ch[0];
		}
	}
}

// Ascending-order comparator for sort(): 1 when a precedes b, otherwise 0.
int cmp(int a, int b)
{
	if(a < b)
		return 1;
	return 0;
}
// Takes the next slot from the node pool and returns it as a single-node
// tree with zero weight whose three links all point at the sentinel.
// The a and b fields are left for the caller to fill in.
node* newnode(void)
{
	node *const fresh = ++top;
	fresh->s = 1;
	fresh->v = 0;
	fresh->sum = 0;
	fresh->ch[0] = null;
	fresh->ch[1] = null;
	fresh->fa = null;
	return fresh;
}
// Resets the node pool for a new test case: C[0] becomes the sentinel
// (size, weight and sum 0, all links pointing at itself) and the tree starts
// as one zero-weight node with a = b = 0. build() later places that node
// leftmost, so every real node has a predecessor.
void init(void)
{
	null = C;
	top = C + 1;

	null->s = 0;
	null->v = 0;
	null->sum = 0;
	null->ch[0] = null;
	null->ch[1] = null;
	null->fa = null;

	root = newnode();
	root->a = 0;
	root->b = 0;
}
// Reads one test case and compresses the label range 1..n.
//
// Input: a header "n m", then m lines "<word> <int>". Only the first letter
// of the word is stored (op[i].p); the integer goes to op[i].v.
//
// Every label named by a 'T' or 'Q' operation becomes its own single-label
// range in po[]; each maximal run of labels between them becomes one range.
// On return:
//   xx[1..xcnt]  sorted distinct 'T'/'Q' labels,
//   x[i]         index in po[] of the range holding xx[i],
//   po[1..ycnt]  all ranges in increasing label order, covering 1..n.
void read(void)
{
	xcnt = ycnt = 0;
	if(scanf("%d%d", &n, &m) != 2) {
		// Truncated input: treat it as an empty case so stale values from a
		// previous case are not replayed.
		n = m = 0;
	}
	for(int i = 1; i <= m; i++) {
		// The width keeps the command word inside the 100-byte buffer.
		scanf("%99s%d", s, &op[i].v);
		op[i].p = s[0];
		if(s[0] == 'T' || s[0] == 'Q')
			xx[++xcnt] = op[i].v;
	}

	// Sort and de-duplicate. std::unique handles the empty range, so a case
	// with no 'T'/'Q' operation leaves xcnt at 0 instead of inventing a
	// point from whatever xx[1] happened to hold.
	std::sort(xx + 1, xx + xcnt + 1, cmp);
	xcnt = static_cast<int>(std::unique(xx + 1, xx + xcnt + 1) - (xx + 1));

	int b = 1;	// smallest label not yet assigned to a range
	for(int i = 1; i <= xcnt; i++) {
		if(xx[i] > b) {
			// Labels strictly between the previous point and this one.
			++ycnt;
			po[ycnt].a = b;
			po[ycnt].b = xx[i] - 1;
		}
		++ycnt;
		po[ycnt].a = po[ycnt].b = xx[i];
		b = xx[i] + 1;
		x[i] = ycnt;
	}
	if(n >= b) {
		// Labels after the last point.
		++ycnt;
		po[ycnt].a = b;
		po[ycnt].b = n;
	}
}
// Binary-searches xx[1..xcnt] for tmp and returns the matching x[] entry,
// i.e. the po[]/loc[] index of that label's single-label range.
// The caller must pass a value that is present: work() only passes 'T'/'Q'
// arguments, which read() stored in xx[].
int search(int tmp)
{
	int lo = 1, hi = xcnt, mid;
	for(; lo <= hi; ) {
		mid = (lo + hi) >> 1;
		const int probe = xx[mid];
		if(probe == tmp)
			break;
		if(probe > tmp)
			hi = mid - 1;
		else
			lo = mid + 1;
	}
	return x[mid];
}
// Debugging aid: prints v, a, b and sum of every node under o, left subtree
// first, then the node, then the right subtree. o must not be the sentinel.
void debug(node *o)
{
	node *const lc = o->ch[0];
	node *const rc = o->ch[1];
	if(lc != null)
		debug(lc);
	printf("AA %d %d %d %d BB\n", o->v, o->a, o->b, o->sum);
	if(rc != null)
		debug(rc);
}
void build(void)
{
	node *p;
	for(int i = 1; i <= ycnt; i++) {
		p = newnode();
		p->v = po[i].b - po[i].a + 1;
		p->a = po[i].a , p->b = po[i].b;
		loc[i] = p;
		p->ch[0] = root;
		root->fa = p;
		root = p;
		root->maintain();
	}
}
void work(void)
{
	int now, ans;
	node *mid, *left, *right, *p;
	for(int i = 1; i <= m; i++) {
		if(op[i].p == 'T') {
			now = search(op[i].v);
			root = loc[now];
			splay(root);
			kth(root, root->ch[0]->s);
			split(root, left, p);
			kth(p, 1);
			split(p, mid, right);
			p = merge(left, right);
			kth(p, 1);
			split(p, left, right);
			p = merge(left, mid);
			kth(p, 2);
			root = merge(p, right);
		}
		if(op[i].p == 'Q') {
			now = search(op[i].v);
			root = loc[now];
			splay(root);
			ans = root->ch[0]->sum;
			ans += op[i].v - root->a + 1;
			printf("%d\n", ans);
		}
		if(op[i].p == 'R') {
			find(root, op[i].v);
		}
	}
}
// Entry point. The input is a sequence of groups, each starting with a case
// count followed by that many test cases; case numbering restarts at 1 for
// every group. Reading stops at end of input or at the first token that is
// not an integer.
int main(void)
{
	int cases;
	// Compare with 1, not EOF: a malformed token makes scanf return 0, which
	// would otherwise loop forever on an unset counter.
	while(scanf("%d", &cases) == 1) {
		int caseNo = 0;
		// "> 0" keeps a negative count from running away.
		while(cases-- > 0) {
			init();
			read();
			build();
			printf("Case %d:\n", ++caseNo);
			work();
		}
	}
	return 0;
}


你可能感兴趣的:(HDU)