您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)
第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继
对于操作1,2,4,5各输出一行,表示查询结果
1.n和m的数据范围:n,m<=50000
2.序列中每个数的数据范围:[0,1e8]
题解:线段树套splay
先建立一棵区间线段树,然后线段树中的每个节点建立一棵位置在当前点表示的区间的权值splay(小的在左儿子,大的在右儿子)
solve1: 直接把所有在范围内的区间内比他小的数统计一下,然后+1
solve2: 二分答案,通过solve1,计算mid在区间中的排名
solve3: 把这个位置原本的数从所有包含这个位置的区间中删去,然后加入新的值
solve4:从所有在范围内的区间(给出的范围可能在线段树中跨越多个区间)中找前驱最大的
solve5:从所有在范围内的区间(给出的范围可能在线段树中跨越多个区间)中找后继最小的
思路非常清晰明了,但是实现起来异常的麻烦,各种手残简直鬼畜。
刚开始姿势不够优越,TLE。改了姿势后刚好过了,"9724 ms"。。。。。
<span style="font-size:18px;">#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define N 2000003 #define M 50003 using namespace std; int n,m; int ls[4*M],rs[4*M],root[4*M],pd; int ch[N][3],fa[N],maxn,a[N],size[N],key[N],cnt[N],sz; void clear(int x) { size[x]=key[x]=cnt[x]=ch[x][1]=ch[x][0]=fa[x]=0; } int get(int x) { return ch[fa[x]][1]==x; } void update(int x) { size[x]=cnt[x]; if (ch[x][0]) size[x]+=size[ch[x][0]]; if (ch[x][1]) size[x]+=size[ch[x][1]]; } void rotate(int x) { int y=fa[x]; int z=fa[y]; int which=get(x); if (z) ch[z][ch[z][1]==y]=x; fa[x]=z; ch[y][which]=ch[x][which^1]; fa[ch[y][which]]=y; ch[x][which^1]=y; fa[y]=x; update(y); update(x); } void splay(int x,int &root) { for (int f;(f=fa[x]);rotate(x)) if (fa[f]) rotate(get(x)==get(f)?f:x); root=x; } void insert(int &root,int x) { if (!root) { root=++sz; clear(sz); size[sz]=cnt[sz]=1; key[sz]=x; return; } int f=0; int now=root; while(true) { if (x==key[now]) { cnt[now]++; update(now); splay(now,root); return; } f=now; now=ch[now][key[now]<x]; if (!now) { sz++; clear(sz); key[sz]=x; cnt[sz]=size[sz]=1; fa[sz]=f; ch[f][key[f]<x]=sz; update(f); splay(sz,root); return; } } } int find(int x,int &root) //查找x的位置 { int now=root; while (true) { if (now==0) return 0; if (key[now]==x) { splay(now,root); return now; } if (x<key[now]) now=ch[now][0]; if (x>key[now]) now=ch[now][1]; } } int findx(int x,int &root) //查找当前区间内比x小的数有多少个 { int now=root; int ans=0; while (true) { if (!now) return ans; if (x<key[now]) now=ch[now][0]; else { ans+=(ch[now][0]?size[ch[now][0]]:0); if (x==key[now]) { splay(now,root); pd=true; return ans; } ans+=cnt[now]; now=ch[now][1]; } } } int pre(int root) { int now=ch[root][0]; while (ch[now][1]) now=ch[now][1]; return now; } int next(int root) { int now=ch[root][1]; while (ch[now][0]) now=ch[now][0]; return now; } void del(int &root,int x) { splay(x,root); if (cnt[root]>1) { cnt[root]--; update(root); return; } if (!ch[root][1]&&!ch[root][0]) { clear(root); root=0; return ; } if (!ch[root][1]) { int old=root; root=ch[root][0]; fa[root]=0; clear(old); return; } if (!ch[root][0]) { int old=root; root=ch[root][1]; fa[root]=0; clear(old); return; } int k=pre(root); int old=root; splay(k,root); ch[k][1]=ch[old][1]; fa[ch[k][1]]=k; clear(old); update(k); return; } void pointchange(int now,int l,int r,int x,int v) { insert(root[now],v); if (l==r) return; int mid=(l+r)/2; if (x<=mid) pointchange(now<<1,l,mid,x,v); else pointchange(now<<1|1,mid+1,r,x,v); } int solve1(int now,int l,int r,int ll,int rr,int k) { if (l>=ll&&r<=rr) { return findx(k,root[now]); } int mid=(l+r)/2; int ans=0; if (ll<=mid) ans+=solve1(now<<1,l,mid,ll,rr,k); if (rr>mid) ans+=solve1(now<<1|1,mid+1,r,ll,rr,k); return ans; } void solve2(int l,int r,int k) { int head=0,tail=maxn+1,mid; while (head<tail) { mid=(head+tail)/2; if (solve1(1,1,n,l,r,mid)<k) head=mid+1; else tail=mid; } printf("%d\n",head-1); } void solve3(int now,int l,int r,int x,int v) { int t=find(a[x],root[now]); del(root[now],t); insert(root[now],v); if (l==r) return; int mid=(l+r)/2; if (x<=mid) solve3(now<<1,l,mid,x,v); else solve3(now<<1|1,mid+1,r,x,v); } int find_next_min(int rt,int x) { int now=root[rt],t=0,ans=-1; while (now) { if (key[now]<x) { if (ans<key[now])ans=key[now]; now=ch[now][1]; } else now=ch[now][0]; } return ans; } int find_next_max(int rt,int x) { int now=root[rt],t=0,ans=1000000000; while (now) { if (key[now]>x) { if (ans>key[now]) ans=key[now]; now=ch[now][0]; } else now=ch[now][1]; } return ans; } int solve4(int now,int l,int r,int ll,int rr,int x) { if (l>=ll&&r<=rr) { return find_next_min(now,x); } int mid=(l+r)/2; int maxn=0; if (ll<=mid) maxn=max(maxn,solve4(now<<1,l,mid,ll,rr,x)); if (rr>mid) maxn=max(maxn,solve4(now<<1|1,mid+1,r,ll,rr,x)); return maxn; } int solve5(int now,int l,int r,int ll,int rr,int x) { if (l>=ll&&r<=rr) { return find_next_max(now,x); } int mid=(l+r)/2; int minn=1000000000; if (ll<=mid) minn=min(minn,solve5(now<<1,l,mid,ll,rr,x)); if (rr>mid) minn=min(minn,solve5(now<<1|1,mid+1,r,ll,rr,x)); return minn; } int main() { freopen("input.txt","r",stdin); freopen("my.out","w",stdout); scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) scanf("%d",&a[i]),maxn=max(maxn,a[i]); for (int i=1;i<=n;i++) pointchange(1,1,n,i,a[i]); for (int i=1;i<=m;i++) { int op,x,y,k; scanf("%d%d%d",&op,&x,&y); if (op!=3) scanf("%d",&k); switch(op) { case 1: printf("%d\n",solve1(1,1,n,x,y,k)+1); break; //注意rank是比他小的个数+1 case 2: solve2(x,y,k); break; case 3: solve3(1,1,n,x,y); maxn=max(maxn,y); a[x]=y; break; case 4: printf("%d\n",solve4(1,1,n,x,y,k)); break; case 5: printf("%d\n",solve5(1,1,n,x,y,k)); break; } } } </span>