【BZOJ3196】【Tyvj1730】二逼平衡树,第一次的树套树(线段树+splay)

传送门1
传送门2
写在前面:创造迄今最长的正常代码的记录
思路:个人感觉这个树套树就是对线段树的每个区间建一棵splay来维护,最初觉得这个方法会爆T爆M……(实际上真的可能会爆)。对于5个操作,我们有如下策略
对于操作1,我们比较容易想到,寻找k在[l,r]上的排名就是求[l,r]中比k小的数的数量+1,这等价于找出它在[l,mid]和[mid+1,r]上比他小的数的总数量+1,然后就可以线段树一层层套下去,再用Splay的rank函数查找了
对于操作2,这是一个比较麻烦的,因为它不能像1一样在区间中合并,但数据范围是[0,10^8],所以我们可以令l=0,r=10^8,二分查找mid的排名,最后得到正解
对于操作3,这是单点修改,所以直接一直放下去,并修改所在的个区间的splay(先del原数值再insert新数值)
对于操作4,5,显然答案也是可以在区间上合并的,前驱找出最大的,后继找出最小的,一层层下放,如果查询区间覆盖了当前线段树的节点区间,就直接调用Splay的前驱后继函数
注意:
1.记录每个splay的根节点并需要实时修改,推荐通过记录其下标来修改(代码中rt全部为当前splay的根在数组中的下标),毕竟取地址符什么看起来就很不舒服= =
2.废物利用,每次修改操作时记录下原数在splay中的下标,到时候插入的时候直接用就行了(记得初始化),防止下标加的过多导致RE(如果原数出现次数大于1则不能这么做,只能再开一个下标)
3.代码在BZOJ上测试通过,但截至发文时间,Tyvj服务器一直处于崩溃,无法评测,po主在cogs上评测T了两个点……开O2加inline也不管用……

#include<bits/stdc++.h>
#define pd(i) (i>='0'&&i<='9')
using namespace std;
int n,m,tot;
int num[50003],roots[2000003];
struct Splay
{
    int fa,ch[2],siz,data,occ;
}a[2000003];
int in()
{
    int t=0,f=1;
    char ch=getchar();
    while (!pd(ch))
    {
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while (pd(ch)) t=(t<<3)+(t<<1)+ch-'0',ch=getchar();
    return f*t;
}
void ct(int x)
{
    a[x].siz=a[a[x].ch[0]].siz+a[a[x].ch[1]].siz+a[x].occ;
}
void made(int x,int id)
{
    a[id].data=x,
    a[id].occ=a[id].siz=1,
    a[id].ch[0]=a[id].ch[1]=a[id].fa=0;
}
void rorate(int now,bool mk)
{
    int pa=a[now].fa;
    a[a[now].ch[mk]].fa=pa;
    a[pa].ch[!mk]=a[now].ch[mk];
    a[now].fa=a[pa].fa;
    if (a[pa].fa)
    {
        if (a[a[pa].fa].ch[0]==pa) a[a[pa].fa].ch[0]=now;
        else a[a[pa].fa].ch[1]=now;
    }
    a[now].ch[mk]=pa;
    a[pa].fa=now;
    ct(pa);ct(now);
}
void splay(int rt,int now,int goal)
{
    int pa;
    while (a[now].fa!=goal)
    {
        pa=a[now].fa;
        if (a[pa].fa==goal)
        {
            if (a[pa].ch[0]==now) rorate(now,1);
            else rorate(now,0);
        }
        else if (a[a[pa].fa].ch[0]==pa)
        {
            if (a[pa].ch[0]==now) rorate(pa,1);
            else rorate(now,0);
            rorate(now,1);
        }
        else
        {
            if (a[pa].ch[1]==now) rorate(pa,0);
            else rorate(now,1);
            rorate(now,0);
        }
    }
    if (!goal) roots[rt]=now;
}
void insert(int rt,int x,int id)
{
    if (!roots[rt]) {made(x,id);roots[rt]=id;return;}
    int now=roots[rt];
    while (now)
    {
        if (a[now].data==x) {a[now].occ++;a[now].siz++;splay(rt,now,0);return;}
        if (a[now].data>x)
        {
            if (!a[now].ch[0]) {made(x,id);a[now].ch[0]=id;a[id].fa=now;break;}
            else now=a[now].ch[0];
        }
        else
        {
            if (!a[now].ch[1]) {made(x,id);a[now].ch[1]=id;a[id].fa=now;break;}
            else now=a[now].ch[1];
        }
    }
    splay(rt,id,0);
}
int find(int root,int x)
{
    int now=root;
    while (now)
    {
        if (a[now].data==x) return now;
        if (a[now].data>x) now=a[now].ch[0];
        else now=a[now].ch[1];
    }
}
int findmax(int now)
{
    while (a[now].ch[1]) now=a[now].ch[1];
    return now;
}
int find_next_min(int rt,int x)
{
    int now=roots[rt],t=0,ans=-0x7fffffff;
    while (now)
    {
        if (a[now].data<x)
        {
            if (ans<a[now].data)ans=a[now].data,t=now;
            now=a[now].ch[1];
        }
        else now=a[now].ch[0];
    }
    return ans;
}
int find_next_max(int rt,int x)
{
    int now=roots[rt],t=0,ans=0x7fffffff;
    while (now)
    {
        if (a[now].data>x) 
        {
            if (ans>a[now].data) ans=a[now].data,t=now;
            now=a[now].ch[0];
        }
        else now=a[now].ch[1];
    }
    return ans;
}
void replace(int rt,int x,int k)
{
    int now=find(roots[rt],x);
    splay(rt,now,0);
    if (a[now].occ>1) {a[now].occ--;a[now].siz--;}
    else if (a[now].siz==1) roots[rt]=0;
    else if (!a[now].ch[0])
    {
        roots[rt]=a[now].ch[1];
        a[a[now].ch[1]].fa=0;
    }
    else if (!a[now].ch[1])
    {
        roots[rt]=a[now].ch[0];
        a[a[now].ch[0]].fa=0;
    }
    else
    {
        splay(rt,findmax(a[now].ch[0]),now);
        a[a[now].ch[0]].ch[1]=a[now].ch[1];
        a[a[now].ch[1]].fa=a[now].ch[0];
        a[a[now].ch[0]].fa=0;
        roots[rt]=a[now].ch[0];
        ct(a[now].ch[0]);
    }
    if (!a[now].occ)insert(rt,k,now);
    else insert(rt,k,++tot);
}
int find_rank(int rt,int x)//这里的findrank实际上是在splay里找比x小的数的数量
{
    int now=roots[rt],ans=0;
    while (now)
    {
        if (a[now].data>x) now=a[now].ch[0];
        else if (a[now].data<x)
            ans+=(a[now].occ+a[a[now].ch[0]].siz),
            now=a[now].ch[1];
        else {ans+=a[a[now].ch[0]].siz;break;}
    }
    return ans; 
}
void build(int now,int begin,int end)
{
    for (int i=begin;i<=end;i++) insert(now,num[i],++tot);
    if (begin==end) return;
    int mid=(begin+end)>>1;
    build(now<<1,begin,mid);
    build(now<<1|1,mid+1,end);
}
int solve1(int now,int begin,int end,int l,int r,int k)
{
    if (l<=begin&&end<=r) return find_rank(now,k);
    int mid=(begin+end)>>1,rank=0;
    if (mid>=l) rank+=solve1(now<<1,begin,mid,l,r,k);
    if (mid<r) rank+=solve1(now<<1|1,mid+1,end,l,r,k);
    return rank;
}
int solve2(int l,int r,int k)
{
    int begin=0,end=1e8+1,mid;
    while (begin<end)
    {
        mid=(begin+end)>>1;
        if (solve1(1,1,n,l,r,mid)<k)
            begin=mid+1;
        else end=mid;
    }
    return begin-1;
}
void solve3(int now,int begin,int end,int pos,int k)
{
    replace(now,num[pos],k);
    if (begin==end) {num[pos]=k;return;}
    int mid=(begin+end)>>1;
    if (mid>=pos) solve3(now<<1,begin,mid,pos,k);
    else solve3(now<<1|1,mid+1,end,pos,k);
}
int solve4(int now,int begin,int end,int l,int r,int k)
{
    if (l<=begin&&end<=r) return find_next_min(now,k);
    int mid=(begin+end)>>1,ans=-0x7fffffff;
    if (mid>=l) ans=max(ans,solve4(now<<1,begin,mid,l,r,k));
    if (mid<r) ans=max(ans,solve4(now<<1|1,mid+1,end,l,r,k));
    return ans;
}
int solve5(int now,int begin,int end,int l,int r,int k)
{
    if (l<=begin&&end<=r) return find_next_max(now,k);
    int mid=(begin+end)>>1,ans=0x7fffffff;
    if (mid>=l) ans=min(ans,solve5(now<<1,begin,mid,l,r,k));
    if (mid<r) ans=min(ans,solve5(now<<1|1,mid+1,end,l,r,k));
    return ans;
}
main()
{
    n=in();m=in();
    int opt,x,y,k;
    for (int i=1;i<=n;i++) num[i]=in();
    build(1,1,n);
    while (m--)
    {
        opt=in();
        if (opt!=3)x=in(),y=in(),k=in();
        else x=in(),y=in();
        if (opt==1) printf("%d\n",solve1(1,1,n,x,y,k)+1);
        else if (opt==2) printf("%d\n",solve2(x,y,k));
        else if (opt==3) solve3(1,1,n,x,y);
        else if (opt==4) printf("%d\n",solve4(1,1,n,x,y,k));
        else printf("%d\n",solve5(1,1,n,x,y,k));
    }
}

你可能感兴趣的:(【BZOJ3196】【Tyvj1730】二逼平衡树,第一次的树套树(线段树+splay))