POJ 2761 SBT(size balanced tree)

题意:给出n个数字m个询问,每次询问在范围[a,b]内第k小数字是多少,并输出。

思路:网上的最优解法是划分树或者归并树,最近学了SBT(size balanced tree),正好里面有类似的select(t,k)选择第k小数的操作,所以就用SBT写了一下。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <string>
#include <cmath>
#include <cstring>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <iomanip>
#define PI acos(-1.0)
#define Max 1000005
#define inf 1<<28
#define LL(x) (x<<1)
#define RR(x) (x<<1|1)
#define FOR(i,s,t) for(int i=(s);i<=(t);++i)
#define ll long long
#define mem(a,b) memset(a,b,sizeof(a))
#define mp(a,b) make_pair(a,b)
using namespace std;

struct SBT
{
    int left , right , num ,size;
}tree[Max];

void left_rot(int &x)//左旋
{
    int y = tree[x].right;
    tree[x].right = tree[y].left;
    tree[y].left = x;
    tree[y].size = tree[x].size;
    tree[x].size = tree[tree[x].left].size + tree[tree[x].right].size + 1;
    x = y;
}

void right_rot(int &x)//右旋
{
    int y = tree[x].left;
    tree[x].left = tree[y].right;
    tree[y].right = x;
    tree[y].size = tree[x].size;
    tree[x].size = tree[tree[x].left].size + tree[tree[x].right].size + 1;
    x = y;
}

void maintain(int &x,bool flag)//更新
{
    if(!flag)//左子树是0
    {
        if(tree[tree[tree[x].left].left].size > tree[tree[x].right].size)
        right_rot(x);
        else if(tree[tree[tree[x].left].right].size > tree[tree[x].right].size)
        {
            left_rot(tree[x].left);
            right_rot(x);
        }
        else return ;
    }
    else//右子树是1
    {
        if(tree[tree[tree[x].right].right].size > tree[tree[x].left].size)
        left_rot(x);
        else if(tree[tree[tree[x].right].left].size > tree[tree[x].left].size)
        {
            right_rot(tree[x].right);
            left_rot(x);
        }
        else return ;
    }
    maintain(tree[x].left,0);
    maintain(tree[x].right,1);
    maintain(x,1);
    maintain(x,0);
}

int root , top;
void insert(int &x,int num)//插入
{
    if(x == 0)
    {
        x = ++top;
        tree[x].left = tree[x].right = 0;
        tree[x].size = 1;
        tree[x].num = num ;
    }
    else
    {
        tree[x].size ++;
        if(num < tree[x].num )insert(tree[x].left , num);
        else insert(tree[x].right , num);
        maintain(x,num >= tree[x].num);
    }
}

int select(int &x,int num)//选择第num个小数
{
    int r = tree[tree[x].left].size + 1;
    if(r == num )return tree[x].num;
    else if (r > num)return select(tree[x].left,num);
    else return select(tree[x].right,num - r);
}
int del(int &x,int num)//删除num值
{
    int d_num ;
    if(!x)return 0;
    tree[x].size --;
    if(num == tree[x].num ||(tree[x].num > num && tree[x].left == 0)||(tree[x].num < num && tree[x].right == 0))
    {
        d_num = tree[x].num;
        if( tree[x].left && tree[x].right )
        {
            tree[x].num = del(tree[x].left,tree[x].num + 1);
        }
        else
        {
            x = tree[x].left + tree[x].right ;
        }
    }
    else if(num > tree[x].num)
    d_num = del(tree[x].right , num);
    else if(num < tree[x].num )
    d_num = del(tree[x].left ,num);
    return d_num;
}
int num[100005];
struct kdq
{
    int s,e,idx,ans,th;
}ans[50005];//代表询问,[s,e]是区间.th是第几,ans是输出的结果,idx是该询问的次序。
bool cmp(kdq x,kdq y)//按询问区间从前到后排
{
    if(x.s == y.s)
    return x.e<y.e;
    return x.s<y.s;
}

bool cmp1(kdq x,kdq y)//按询问次序排,最后按顺序输出。
{
    return x.idx<y.idx;
}

int main()
{
    int n , m;
    while(scanf("%d%d",&n,&m) != EOF)
    {
        for (int i = 1 ; i <= n ;i ++)scanf("%d",&num[i]);
        for (int i = 0 ;i < m ;i ++){
            scanf("%d%d%d",&ans[i].s,&ans[i].e,&ans[i].th);
            ans[i].idx = i ;
        }
        sort(ans,ans + m ,cmp);
        root = top =0;
        for (int i = ans[0].s ; i <=ans[0].e ; i ++)insert(root , num[i]);
        ans[0].ans = select(root , ans[0].th);
        for (int i = 1; i < m ;i ++)
        {
            if(ans[i].s >= ans[i-1].e)//如果后面区间和前面的区间完全没有交集,则删除所有前面区间元素,插入所有后面区间的元素。
            {
                for (int j = ans[i-1].s ; j <= ans[i-1].e ;j ++)del(root , num[j]);
                for (int j = ans[i].s ; j <= ans[i].e ; j ++)insert(root ,num[j]);
            }
            else if(ans[i].e <= ans[i-1].e)//如果后面的区间包含于前面的区间,则删除前面区间多余的元素
            {
                for (int j = ans[i-1].s ; j < ans[i].s ; j ++)del(root ,num[j]);
                for (int j = ans[i].e + 1 ; j <=ans[i-1].e ;j ++ )del(root,num[j]);
            }
            else//如果后面的区间和前面的区间有部分交集,则删除前面区间多余的元素,插入后面区间多出的元素
            {
                for (int j = ans[i-1].s ; j < ans[i].s ; j ++)del(root ,num[j]);
                for (int j = ans[i-1].e + 1; j <= ans[i].e ;j ++)insert(root ,num[j]);
            }
//这里解释有点繁琐,但是不难想的
            ans[i].ans = select(root ,ans[i].th);//记录结果
        }
        sort(ans,ans + m ,cmp1);
        for (int i = 0 ;i < m ; i ++)printf("%d\n",ans[i].ans);
    }
    return 0;
}
同类型的题POJ 2104,用此方法TLE了。继续去学习其他树。。



你可能感兴趣的:(POJ 2761 SBT(size balanced tree))