题意:您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)
题解:树套树,外层是一棵线段树,每个节点下有一棵平衡树(平衡树记录ls,rs,因此记录根节点就可以遍历整棵树),先不考虑空间问题,ask(l,r)可以分成多个线段树区间,每个区间下有平衡树可以查询排名,1&3-5操作都很简单地能够实现,查询排名需要二分答案以及一些特殊技巧,我的做法是,二分出来可能的最小值,然后取它的前驱的后继。不知道是否有非法操作(比如:1 3 3 3 4 xth 3)又加了一些特殊技巧,都是细节。
空间问题之前一直搞不懂,以为线段树有4*N个节点,每个节点要一棵N个节点的平衡树,空间会超。其实不是的。线段树的每层节点一共需要N个节点的线段树,一共logN层,加在一起的平衡树节点个数为NlogN。加上修改操作,空间复杂度是O((M+N)logN+4*N)
时间复杂度是O(NlogNlogNlogMAXai)
这道题还发现了一个我写线段树的毛病,算是小瑕疵吧。警戒一下自己:线段树写的丑 别忘处理边界问题!if(k>=4*N) return;
还有一种做法是主席树?会快一点?
#include<iostream>
#include<cstdio>
#include<cstdlib>
#define N 50005
#define NN 1600000 //(M+N)logN
#define inf 1<<30
#define pa pair<int,int>
#define mp make_pair
using namespace std;
struct seg{int l,r,rt;}st[N*4];
struct treap{int ls,rs,val,key,sz;}T[NN];
int n,m,num=0,a[N],ret_getrank,ret_getpre,ret_getsuc;
void upd(int x){T[x].sz=T[T[x].ls].sz+T[T[x].rs].sz+1;}
int merge(int a,int b){
if(a==0||b==0) return a+b;
if(T[a].key<T[b].key) {T[a].rs=merge(T[a].rs,b),upd(a);return a;}
else {T[b].ls=merge(a,T[b].ls),upd(b);return b;}
}
pa split(int a,int k){
pa tmp;
if(k==0) return mp(0,a);
int ls=T[a].ls,rs=T[a].rs;
if(T[ls].sz==k) {T[a].ls=0;upd(a);return mp(ls,a);}
if(T[ls].sz+1==k) {T[a].rs=0;upd(a);return mp(a,rs);}
if(T[ls].sz>k) {tmp=split(ls,k);T[a].ls=tmp.second;upd(a);return mp(tmp.first,a);}
if(T[ls].sz+1<k) {tmp=split(rs,k-T[ls].sz-1);T[a].rs=tmp.first;upd(a);return mp(a,tmp.second);}
}
pa rank(int k,int x){
int tmp=inf,ans=0;
while(k){
if(T[k].val==x) tmp=min(tmp,ans+T[T[k].ls].sz+1),k=T[k].ls;
else if(T[k].val<x) ans+=T[T[k].ls].sz+1,k=T[k].rs;
else if(T[k].val>x) k=T[k].ls;
}
return tmp==inf?mp(ans,0):mp(tmp,1);
}
int pre(int k,int x){
int ans=-inf;
while(k){
if(T[k].val<x) ans=max(ans,T[k].val),k=T[k].rs;
else k=T[k].ls;
} return ans;
}
int suc(int k,int x){
int ans=inf;
while(k){
if(T[k].val>x) ans=min(ans,T[k].val),k=T[k].ls;
else k=T[k].rs;
} return ans;
}
void ins(int &r,int x){ //st[r].rt = root
int k=rank(r,x).first;
T[++num].ls=0;T[num].rs=0;T[num].sz=1;
T[num].val=x;T[num].key=rand();
pa tmp=split(r,k);
r=merge(tmp.first,num);
r=merge(r,tmp.second);
}
void del(int &r,int x){
int k=rank(r,x).first;
pa tmp1=split(r,k);
pa tmp2=split(tmp1.first,k-1);
r=merge(tmp2.first,tmp1.second);
}
void getrank(int k,int a,int b,int c){
if(k>=4*N) return;
int l=st[k].l,r=st[k].r;
if(r<a || b<l) return;
if(a<=l && r<=b) {pa tmp=rank(st[k].rt,c);ret_getrank+=tmp.first-tmp.second;return;}
getrank(k<<1,a,b,c);getrank(k<<1|1,a,b,c);
}
void getpre(int k,int a,int b,int c){
if(k>=4*N) return;
int l=st[k].l,r=st[k].r;
if(r<a || b<l) return;
if(a<=l && r<=b) {ret_getpre=max(ret_getpre,pre(st[k].rt,c));return;}
getpre(k<<1,a,b,c);getpre(k<<1|1,a,b,c);
}
void getsuc(int k,int a,int b,int c){
if(k>=4*N) return;
int l=st[k].l,r=st[k].r;
if(r<a || b<l) return;
if(a<=l && r<=b) {ret_getsuc=min(ret_getsuc,suc(st[k].rt,c));return;}
getsuc(k<<1,a,b,c);getsuc(k<<1|1,a,b,c);
}
int getxth(int a,int b,int c){
int ll=0,rr=inf;
while(ll<rr){
int mid=(rr+ll)>>1;
ret_getrank=0;
getrank(1,a,b,mid);
if(ret_getrank+1>=c) rr=mid;
else ll=mid+1;
}
ret_getrank=0;
getrank(1,a,b,ll);
if(ret_getrank+1!=c) {ret_getpre=0;getpre(1,a,b,ll);return ret_getpre;}
ret_getpre=0;ret_getsuc=inf;
getpre(1,a,b,ll);
getsuc(1,a,b,ret_getpre);
return ret_getsuc;
}
void modify(int k,int t,int x){
if(k>=4*N) return;
int l=st[k].l,r=st[k].r;
if(t<l||r<t) return;
if(l<=t&&t<=r) {del(st[k].rt,a[t]);ins(st[k].rt,x);}
modify(k<<1,t,x);modify(k<<1|1,t,x);
}
void build(int k,int l,int r){
st[k].l=l;st[k].r=r;
for(int i=l;i<=r;i++) ins(st[k].rt,a[i]);
if(l==r) return;
int mid=(l+r)>>1;
build(k<<1,l,mid);build(k<<1|1,mid+1,r);
}
int main(){
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
build(1,1,n);
while(m--){
int opt,x,y,c;scanf("%d",&opt);ret_getrank=0;ret_getpre=0;ret_getsuc=inf;
if(opt==1) scanf("%d%d%d",&x,&y,&c),getrank(1,x,y,c),printf("%d\n",ret_getrank+1);
if(opt==2) scanf("%d%d%d",&x,&y,&c),printf("%d\n",getxth(x,y,c));
if(opt==3) scanf("%d%d",&x,&y),modify(1,x,y),a[x]=y;
if(opt==4) {scanf("%d%d%d",&x,&y,&c);getpre(1,x,y,c);ret_getpre=ret_getpre==-inf?0:ret_getpre;printf("%d\n",ret_getpre);}
if(opt==5) {scanf("%d%d%d",&x,&y,&c);getsuc(1,x,y,c);ret_getsuc=ret_getsuc==inf?0:ret_getsuc;printf("%d\n",ret_getsuc);}
}
return 0;
}