POJ 2892 Tunnel Warfare (SBT + stack)


题意:给定了初始的状态:有n个村庄连成一条直线,现在有三种操作: 1.摧毁一个村庄 2.询问某个村庄,输出与该村庄相连的村庄数量(包括自己) 3.修复被摧毁的村庄,优先修复最近被摧毁的..............

分析:用SBT做的话,摧毁村庄就插入,修复就移除,如果要询问的话:找到第一个大于等于该村庄编号和第一个小于等于该村庄编号的,等价于找到了联通在一起的村庄。


朴素的做法可以 set + stack + 二分 搞之.................

 

#include <iostream>

#include <algorithm>

#include <cmath>

#include <cstdio>

#include <cstdlib>

#include <cstring>

#include <string>

#include <vector>

#include <set>

#include <queue>

#include <stack>

#include <climits>//形如INT_MAX一类的

#define MAX 55555

#define INF 0x7FFFFFFF

#define REP(i,s,t) for(int i=(s);i<=(t);++i)

#define ll long long

#define mem(a,b) memset(a,b,sizeof(a))

#define mp(a,b) make_pair(a,b)

#define L(x) x<<1

#define R(x) x<<1|1

# define eps 1e-5

//#pragma comment(linker, "/STACK:36777216") ///传说中的外挂

using namespace std;



struct sbt {

    int l,r,s,key;

} tr[MAX];

int top , root;

void left_rot(int &x) {

    int y = tr[x].r;

    tr[x].r = tr[y].l;

    tr[y].l = x;

    tr[y].s = tr[x].s; //转上去的节点数量为先前此处节点的size

    tr[x].s = tr[tr[x].l].s + tr[tr[x].r].s + 1;

    x = y;

}



void right_rot(int &x) {

    int y = tr[x].l;

    tr[x].l = tr[y].r;

    tr[y].r = x;

    tr[y].s = tr[x].s;

    tr[x].s = tr[tr[x].l].s + tr[tr[x].r].s + 1;

    x = y;

}



void maintain(int &x,bool flag) {

    if(flag == 0) { //左边

        if(tr[tr[tr[x].l].l].s > tr[tr[x].r].s) {//左孩子左子树size大于右孩子size

            right_rot(x);

        } else if(tr[tr[tr[x].l].r].s > tr[tr[x].r].s) {//左孩子右子树size大于右孩子size

            left_rot(tr[x].l);

            right_rot(x);

        } else return ;

    } else { //右边

        if(tr[tr[tr[x].r].r].s > tr[tr[x].l].s) { //右孩子的右子树大于左孩子

            left_rot(x);

        } else if(tr[tr[tr[x].r].l].s > tr[tr[x].l].s) { //右孩子的左子树大于左孩子

            right_rot(tr[x].r);

            left_rot(x);

        } else return ;

    }

    maintain(tr[x].l,0);

    maintain(tr[x].r,1);

}



void insert(int &x,int key) {

    if(x == 0) { //空节点

        x = ++ top;

        tr[x].l = tr[x].r = 0;

        tr[x].s = 1;

        tr[x].key = key;

    } else {

        tr[x].s ++;

        if(key < tr[x].key) insert(tr[x].l,key);

        else insert(tr[x].r,key);

        maintain(x,key >= tr[x].key);

    }

}



int remove(int &x,int key) {

    int k;

    tr[x].s --;

    if(key == tr[x].key || (key < tr[x].key && tr[x].l == 0) || (key > tr[x].key && tr[x].r == 0)) {

        k = tr[x].key;

        if(tr[x].l && tr[x].r) {

            tr[x].key = remove(tr[x].l,tr[x].key + 1);

        } else {

            x = tr[x].l + tr[x].r;

        }

    } else if(key > tr[x].key) {

        k = remove(tr[x].r,key);

    } else if(key < tr[x].key) {

        k = remove(tr[x].l,key);

    }

    return k;

}



int pred(int &x,int y,int key)

//前驱 小于

{

    if(x == 0) return tr[y].key ;

    if(tr[x].key < key) return pred(tr[x].r,x,key);

    else if(tr[x].key > key) return pred(tr[x].l,y,key);

    else return key;

}//pred(root,0,key)

int succ(int &x,int y,int key) { //后继 大于

    if(x == 0) return tr[y].key;

    if(tr[x].key > key) return succ(tr[x].l,x,key);

    else if(tr[x].key < key) return succ(tr[x].r,y,key);

    else return key;

}





int n,m;

char c;

int st[MAX];

int head = 0;

int main() {

    root = 0;

    top = 0;

    int b;

    scanf("%d%d",&n,&m);

    for(int i=0; i<m; i++) {

        cin >> c;

        if(c == 'D') {

            scanf("%d",&b);

            st[head++] = b;

            insert(root,b);

        }

        if(c == 'R') {

            remove(root,st[--head]);

        }

        if(c == 'Q') {

            scanf("%d",&b);

            int pre = pred(root,0,b);

            int suc = succ(root,0,b);

            if(suc == 0) suc = n+1;

            if(pre == suc) {

                puts("0");

                continue;

            }

            printf("%d\n",suc - pre - 1);

        }

    }

    return 0;

}


 


 

你可能感兴趣的:(stack)