题意
我用了线段树套平衡树(Splay)去做。
分别说每一问:
第一问:求一个数区间排名:
我们求出区间内小于这个数的个数,加\(1\)即可。
第二问:求区间第K大:
我们显然不能在\(\log n\)个平衡树上求这个东西,于是我们在外面二分答案\(mid\),之后判断其排名与\(k\)的关系即可。
第三问:单点修改:
我们将从\(1\)到叶子的每个点的平衡树都删除原数,加入新数即可。
第四/五问:求区间某个数的前驱和后继:
这两问是相同的,我们只要在\(\log n\)个平衡树上分别求一遍,最后取\(\max\)或\(\min\)即可。
code:
#include
using namespace std;
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
const int maxn=4*1e6+10;
const int inf=2147483647;
int n,m;
int a[maxn],root[maxn];
inline int read()
{
char c=getchar();int res=0,f=1;
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')res=res*10+c-'0',c=getchar();
return res*f;
}
namespace Splay
{
int tot;
int fa[maxn],val[maxn],size[maxn],cnt[maxn];
int ch[maxn][2];
inline void clear(int x){fa[x]=size[x]=cnt[x]=val[x]=ch[x][0]=ch[x][1]=0;}
inline int get(int x){return ch[fa[x]][1]==x;}
inline void up(int x){size[x]=size[ch[x][0]]+size[ch[x][1]]+cnt[x];}
inline void rotate(int x)
{
int y=fa[x],z=fa[y],k=get(x),w=ch[x][k^1];
if(z)ch[z][get(y)]=x;ch[x][k^1]=y;ch[y][k]=w;
if(w)fa[w]=y;fa[y]=x;fa[x]=z;
up(y),up(x);
}
inline void splay(int p,int x,int goal=0)
{
while(fa[x]!=goal)
{
int y=fa[x];
if(fa[y]!=goal)rotate(get(x)==get(y)?y:x);
rotate(x);
}
if(!goal)root[p]=x;
}
inline void insert(int p,int x)
{
if(!root[p])
{
root[p]=++tot;
val[tot]=x;size[tot]=cnt[tot]=1;
return;
}
int now=root[p],last=0;
while(now&&val[now]!=x)last=now,now=ch[now][val[now]x)now=ch[now][0];
else if(val[now]1){cnt[now]--;up(now);return;}
if(!ch[now][0]&&!ch[now][1]){clear(now);root[p]=0;return;}
if(!ch[now][0])
{
int x=ch[now][1];
root[p]=x;fa[x]=0;
return;
}
if(!ch[now][1])
{
int x=ch[now][0];
root[p]=x;fa[x]=0;
return;
}
int x=Pre(p);
splay(p,x);
ch[x][1]=ch[now][1],fa[ch[now][1]]=x;
clear(now); up(x);
}
inline int pre(int p,int k)
{
int now=root[p],res=-inf;
while(now)
{
if(val[now]k)
{
res=min(res,val[now]);
now=ch[now][0];
}
else now=ch[now][1];
}
return res;
}
}
void build(int p,int l,int r)
{
for(int i=l;i<=r;i++)Splay::insert(p,a[i]);
if(l==r)return;
int mid=(l+r)>>1;
build(ls(p),l,mid);build(rs(p),mid+1,r);
}
void change(int p,int l,int r,int pos,int k)
{
Splay::del(p,a[pos]);Splay::insert(p,k);
if(l==r){a[pos]=k;return;}
int mid=(l+r)>>1;
if(pos<=mid)change(ls(p),l,mid,pos,k);
else change(rs(p),mid+1,r,pos,k);
}
int queryrk(int p,int l,int r,int ql,int qr,int k)
{
if(l>=ql&&r<=qr)return Splay::getrk(p,k);
int mid=(l+r)>>1,res=0;
if(ql<=mid)res+=queryrk(ls(p),l,mid,ql,qr,k);
if(qr>mid)res+=queryrk(rs(p),mid+1,r,ql,qr,k);
return res;
}
int querypre(int p,int l,int r,int ql,int qr,int k)
{
if(l>=ql&&r<=qr)return Splay::pre(p,k);
int mid=(l+r)>>1,res=-inf;
if(ql<=mid)res=max(res,querypre(ls(p),l,mid,ql,qr,k));
if(qr>mid)res=max(res,querypre(rs(p),mid+1,r,ql,qr,k));
return res;
}
int querynxt(int p,int l,int r,int ql,int qr,int k)
{
if(l>=ql&&r<=qr)return Splay::nxt(p,k);
int mid=(l+r)>>1,res=inf;
if(ql<=mid)res=min(res,querynxt(ls(p),l,mid,ql,qr,k));
if(qr>mid)res=min(res,querynxt(rs(p),mid+1,r,ql,qr,k));
return res;
}
int main()
{
n=read(),m=read();
for(int i=1;i<=n;i++)a[i]=read();
build(1,1,n);
while(m--)
{
int op=read();
if(op==1)
{
int l=read(),r=read(),k=read();
printf("%d\n",queryrk(1,1,n,l,r,k)+1);
}
if(op==2)
{
int ql=read(),qr=read(),k=read();
int l=0,r=1e8,res=0;
while(l<=r)
{
int mid=(l+r)>>1;
if(queryrk(1,1,n,ql,qr,mid)