bzoj 3196: Tyvj 1730 二逼平衡树

3196: Tyvj 1730 二逼平衡树

Time Limit: 10 Sec   Memory Limit: 128 MB
Submit: 2276   Solved: 937
[ Submit][ Status][ Discuss]

Description

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)

Input

第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继

Output

对于操作1,2,4,5各输出一行,表示查询结果

Sample Input

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

Sample Output

2
4
3
4
9

HINT

1.n和m的数据范围:n,m<=50000


2.序列中每个数的数据范围:[0,1e8]


3.虽然原题没有,但事实上5操作的k可能为负数

Source

[ Submit][ Status][ Discuss]


题解:线段树套splay

先建立一棵区间线段树,然后线段树中的每个节点建立一棵位置在当前点表示的区间的权值splay(小的在左儿子,大的在右儿子)

solve1: 直接把所有在范围内的区间内比他小的数统计一下,然后+1

solve2: 二分答案,通过solve1,计算mid在区间中的排名

solve3: 把这个位置原本的数从所有包含这个位置的区间中删去,然后加入新的值

solve4:从所有在范围内的区间(给出的范围可能在线段树中跨越多个区间)中找前驱最大的

solve5:从所有在范围内的区间(给出的范围可能在线段树中跨越多个区间)中找后继最小的

思路非常清晰明了,但是实现起来异常的麻烦,各种手残简直鬼畜。

刚开始姿势不够优越,TLE。改了姿势后刚好过了,"9724 ms"。。。。。

<span style="font-size:18px;">#include<iostream>  
#include<cstdio>  
#include<cstring>  
#include<algorithm>  
#include<cmath>  
#define N 2000003  
#define M 50003  
using namespace std;  
int n,m;  
int ls[4*M],rs[4*M],root[4*M],pd;  
int ch[N][3],fa[N],maxn,a[N],size[N],key[N],cnt[N],sz;  
void  clear(int x)  
{  
    size[x]=key[x]=cnt[x]=ch[x][1]=ch[x][0]=fa[x]=0;  
}  
int get(int x)  
{  
    return ch[fa[x]][1]==x;  
}  
void update(int x)  
{  
    size[x]=cnt[x];  
    if (ch[x][0]) size[x]+=size[ch[x][0]];  
    if (ch[x][1]) size[x]+=size[ch[x][1]];  
}  
void rotate(int x)  
{  
    int y=fa[x]; int z=fa[y]; int which=get(x);  
    if (z)  
     ch[z][ch[z][1]==y]=x;  
    fa[x]=z; ch[y][which]=ch[x][which^1]; fa[ch[y][which]]=y;  
    ch[x][which^1]=y; fa[y]=x;  
    update(y); update(x);  
}  
void splay(int x,int &root)  
{  
    for (int f;(f=fa[x]);rotate(x))  
     if (fa[f])  
      rotate(get(x)==get(f)?f:x);  
    root=x;  
}  
void insert(int &root,int x)  
{  
    if (!root)  
     {  
        root=++sz; clear(sz);  
        size[sz]=cnt[sz]=1; key[sz]=x;  
        return;  
     }  
    int f=0; int now=root;  
    while(true)  
    {  
        if (x==key[now])  
         {  
            cnt[now]++; update(now); splay(now,root); return;  
         }  
        f=now;  
        now=ch[now][key[now]<x];  
        if (!now)  
        {  
            sz++; clear(sz);  
            key[sz]=x; cnt[sz]=size[sz]=1; fa[sz]=f; ch[f][key[f]<x]=sz;  
            update(f);  splay(sz,root); return;  
        }  
    }  
}  
int find(int x,int &root)  //查找x的位置
{  
    int now=root;  
    while (true)  
     {  
        if (now==0) return 0;
        if (key[now]==x)  
         {  
            splay(now,root);  
            return now;  
         }  
        if (x<key[now])  
         now=ch[now][0];  
        if (x>key[now])  
         now=ch[now][1];  
     }  
}  
int findx(int x,int &root)  //查找当前区间内比x小的数有多少个
{  
  int now=root; int ans=0;    
  while (true)    
   {   
    if (!now)  return ans; 
    if (x<key[now])    
     now=ch[now][0];    
    else   
     {    
       ans+=(ch[now][0]?size[ch[now][0]]:0);    
       if (x==key[now])     
        {    
         splay(now,root); pd=true; return ans;    
        }    
       ans+=cnt[now];    
       now=ch[now][1];    
     }    
   }    
}  
int pre(int root)  
{  
    int now=ch[root][0];  
    while (ch[now][1]) now=ch[now][1];  
    return now;   
}  
int next(int root)  
{  
    int now=ch[root][1];  
    while (ch[now][0]) now=ch[now][0];  
    return now;  
}  
void del(int &root,int x)  
{  
    splay(x,root);  
    if (cnt[root]>1)  
     {  
        cnt[root]--; update(root); return;  
     }  
    if (!ch[root][1]&&!ch[root][0])  
     {  
        clear(root); root=0; return ;  
     }  
    if (!ch[root][1])  
     {  
        int old=root; root=ch[root][0]; fa[root]=0; clear(old); return;  
     }  
    if (!ch[root][0])  
     {  
        int old=root; root=ch[root][1]; fa[root]=0; clear(old); return;  
     }  
    int k=pre(root); int old=root;  splay(k,root);  
    ch[k][1]=ch[old][1]; fa[ch[k][1]]=k; clear(old);  
    update(k); return;  
}  
void pointchange(int now,int l,int r,int x,int v)  
{  
    insert(root[now],v);  
    if (l==r) return;  
    int mid=(l+r)/2;  
    if (x<=mid)  
     pointchange(now<<1,l,mid,x,v);  
    else 
     pointchange(now<<1|1,mid+1,r,x,v);  
}  
int solve1(int now,int l,int r,int ll,int rr,int k)  
{  
    if (l>=ll&&r<=rr)  
     {  
        return findx(k,root[now]);  
     }  
    int mid=(l+r)/2;  
    int ans=0;  
    if (ll<=mid)  
     ans+=solve1(now<<1,l,mid,ll,rr,k);  
    if (rr>mid)  
     ans+=solve1(now<<1|1,mid+1,r,ll,rr,k);  
    return ans;  
}
void  solve2(int l,int r,int k)
{
    int head=0,tail=maxn+1,mid;
    while (head<tail)
    {
        mid=(head+tail)/2;
        if (solve1(1,1,n,l,r,mid)<k)
            head=mid+1;
        else tail=mid;
    }
    printf("%d\n",head-1);
} 
void solve3(int now,int l,int r,int x,int v)  
{  
    int t=find(a[x],root[now]); 
    del(root[now],t);  
    insert(root[now],v);  
    if (l==r) return;  
    int mid=(l+r)/2;  
    if (x<=mid)  
     solve3(now<<1,l,mid,x,v);  
    else 
     solve3(now<<1|1,mid+1,r,x,v);  
} 
int find_next_min(int rt,int x)
{
    int now=root[rt],t=0,ans=-1;
    while (now)
    {
        if (key[now]<x)
        {
            if (ans<key[now])ans=key[now];
            now=ch[now][1];
        }
        else now=ch[now][0];
    }
    return ans;
}
int find_next_max(int rt,int x)
{
    int now=root[rt],t=0,ans=1000000000;
    while (now)
    {
        if (key[now]>x) 
        {
            if (ans>key[now]) ans=key[now];
            now=ch[now][0];
        }
        else now=ch[now][1];
    }
    return ans;
} 
int solve4(int now,int l,int r,int ll,int rr,int x)  
{  
    if (l>=ll&&r<=rr)  
     {  
       return  find_next_min(now,x);
     }  
     int mid=(l+r)/2;  
     int maxn=0;  
     if (ll<=mid)  
      maxn=max(maxn,solve4(now<<1,l,mid,ll,rr,x));  
     if (rr>mid)  
      maxn=max(maxn,solve4(now<<1|1,mid+1,r,ll,rr,x));  
     return maxn;  
}  
int solve5(int now,int l,int r,int ll,int rr,int x)  
{  
    if (l>=ll&&r<=rr)  
     {  
       return  find_next_max(now,x);
     }  
     int mid=(l+r)/2;  
     int minn=1000000000;  
     if (ll<=mid)  
      minn=min(minn,solve5(now<<1,l,mid,ll,rr,x));  
     if (rr>mid)  
      minn=min(minn,solve5(now<<1|1,mid+1,r,ll,rr,x));  
     return minn;  
}  
int main()  
{   
    freopen("input.txt","r",stdin);
    freopen("my.out","w",stdout);
    scanf("%d%d",&n,&m);  
    for (int i=1;i<=n;i++)  
     scanf("%d",&a[i]),maxn=max(maxn,a[i]);  
    for (int i=1;i<=n;i++)  
     pointchange(1,1,n,i,a[i]);  
    for (int i=1;i<=m;i++)  
     {  
        int op,x,y,k; scanf("%d%d%d",&op,&x,&y);  
        if (op!=3) scanf("%d",&k);  
        switch(op)  
        {  
            case 1: printf("%d\n",solve1(1,1,n,x,y,k)+1); break; //注意rank是比他小的个数+1 
            case 2: solve2(x,y,k); break;  
            case 3: solve3(1,1,n,x,y); maxn=max(maxn,y); a[x]=y; break;  
            case 4: printf("%d\n",solve4(1,1,n,x,y,k)); break;  
            case 5: printf("%d\n",solve5(1,1,n,x,y,k)); break;  
        }  
     }  
}  </span>


你可能感兴趣的:(bzoj 3196: Tyvj 1730 二逼平衡树)