线段树套平衡树 可修改的区间第K小问题

在没有修改操作时,应用划分树可以在O(MlogN)时间内解决查找区间第K小的问题,但是在引入修改(将原序列中的某个值改为另一个值)之后,划分树就不行了。
这时,需要数据结构联合的思想。
可以观察一下:
(1)区间操作:使用线段树;
(2)修改值(其实是先删除再插入)和找第K小:使用平衡树;
现在这两种操作都有,应该使用 线段树+平衡树
准确来说是线段树套平衡树,即对原序列建立一棵线段树,其中的每个结点内套一棵对该结点管辖区间内的平衡树。

<1>结点类型(结构):
struct  seg_node {
    
int  l, r, mid, lch, rch, rt;
} T0[MAXN0];
struct  SBT_node {
    
int  v, l, r, p, sz0, sz, mul;
} T[MAXN];
其中seg_node是线段树结点类型,SBT_node是平衡树(SBT)结点类型。需要注意的是seg_node里面的rt域(root的缩写),它是该结点内套的平衡树的根结点下标索引(因为对于任意一棵平衡树,只要知道了其根结点就可以遍历整棵树)。

<2>建树:
建树是线段树和平衡树一起建。在建立线段树结点的时候,先建立一棵空的平衡树(rt域置0),然后再在平衡树里面逐个插入该结点管辖区间内的所有元素即可;

<3>修改:
修改操作要注意:如果要将A[x](A为原序列)的值修改为y,则需要自顶向下遍历整棵线段树,将所有包含了A[x]的结点内的平衡树全部执行“删除v=A[x](这个可以通过真正维护一个序列得到),再插入y”的操作;

<4>找区间第K小:
这个操作极其麻烦。需要借助二分。
设要在区间[l, r]中找到第K小。首先将[l, r]拆分成若干个线段树结点,然后二分一个值x,在这些结点的平衡树中找到x的rank(这里的rank指平衡树中有多少个值比x小,不需要加1),加起来,最后再加1,就是x在[l, r]中的总名次。问题是,设[l..r]中第K小的数为v1,第(K+1)小的数为v2(如果不存在的话,v2=+∞),则[v1, v2)内的数都是“第K小”的。因此,不能二分数字,而应该二分元素。设S[i]为原序列中第i小的数,二分i,然后在根结点的平衡树中找到第i小的即为S[i],再求其名次,这样直到找到总名次为K的元素为止。问题还没完,序列中可能有元素的值相同,这时可能永远也找不到第K小的(比如序列1 2 3 3 3 4 5,K=4,若“序列中比x小的元素总数+1”为x的名次,则永远也找不到第4小的),因此,若这样求出的“名次”小于等于K,都应该将下一次的左边界设为mid而不是(mid+1),而“名次”大于K时,该元素肯定不是第K小的,所以下一次右边界设为(mid-1)。

代码(本机测最猥琐数据4s以内,交到ZJU上TLE,不知为什么,神犇指点一下,3x):
#include  < iostream >
#include 
< stdio.h >
using   namespace  std;
#define  re(i, n) for (int i=0; i<n; i++)
#define  re3(i, l, r) for (int i=l; i<=r; i++)
const   int  MAXN0  =   110000 , MAXN  =   930000 , INF  =   ~ 0U   >>   2 ;
struct  seg_node {
    
int  l, r, mid, lch, rch, rt;
} T0[MAXN0];
struct  SBT_node {
    
int  v, l, r, p, sz0, sz, mul;
} T[MAXN];
int  No0, No, n, root, rt0, a[MAXN0  >>   1 ], b[MAXN0  >>   1 ], l1, r1, len;
void  slc( int  _p,  int  _c)
{
    T[_p].l 
=  _c; T[_c].p  =  _p;
}
void  src( int  _p,  int  _c)
{
    T[_p].r 
=  _c; T[_c].p  =  _p;
}
void  upd( int  x)
{
    T[x].sz0 
=  T[T[x].l].sz0  +  T[T[x].r].sz0  +  T[x].mul;
    T[x].sz 
=  T[T[x].l].sz  +  T[T[x].r].sz  +   1 ;
}
void  lrot( int  x)
{
    
int  y  =  T[x].p;  if  (y  ==  rt0) T[rt0  =  x].p  =   0 else  { int  p  =  T[y].p;  if  (y  ==  T[p].l) slc(p, x);  else  src(p, x);}
    src(y, T[x].l); slc(x, y); T[x].sz0 
=  T[y].sz0; T[x].sz  =  T[y].sz; upd(y);
}
void  rrot( int  x)
{
    
int  y  =  T[x].p;  if  (y  ==  rt0) T[rt0  =  x].p  =   0 else  { int  p  =  T[y].p;  if  (y  ==  T[p].l) slc(p, x);  else  src(p, x);}
    slc(y, T[x].r); src(x, y); T[x].sz0 
=  T[y].sz0; T[x].sz  =  T[y].sz; upd(y);
}
void  maintain( int  x,  bool  ff)
{
    
int  z;
    
if  (ff) {
        
if  (T[T[T[x].r].r].sz  >  T[T[x].l].sz) {z  =  T[x].r; lrot(z);}
        
else   if  (T[T[T[x].r].l].sz  >  T[T[x].l].sz) {z  =  T[T[x].r].l; rrot(z); lrot(z);}  else   return ;
    } 
else  {
        
if  (T[T[T[x].l].l].sz  >  T[T[x].r].sz) {z  =  T[x].l; rrot(z);}
        
else   if  (T[T[T[x].l].r].sz  >  T[T[x].r].sz) {z  =  T[T[x].l].r; lrot(z); rrot(z);}  else   return ;
    }
    maintain(T[z].l, 
0 ); maintain(T[z].r,  1 ); maintain(z,  0 ); maintain(z,  1 );
}
int  find( int  _v)
{
    
int  i  =  rt0, v0;
    
while  (i) {
        v0 
=  T[i].v;
        
if  (_v  ==  v0)  return  i;  else   if  (_v  <  v0) i  =  T[i].l;  else  i  =  T[i].r;
    }
    
return   0 ;
}
void  ins( int  _v)
{
    
if  ( ! rt0) {
        T[
++ No].v  =  _v; T[No].l  =  T[No].r  =  T[No].p  =   0 ; T[No].sz0  =  T[No].sz  =  T[No].mul  =   1 ; rt0  =  No;
    } 
else  {
        
int  i  =  rt0, j, v0;
        
while  ( 1 ) {
            T[i].sz0
++ ; v0  =  T[i].v;
            
if  (_v  ==  v0) {T[i].mul ++ return ;}  else   if  (_v  <  v0) j  =  T[i].l;  else  j  =  T[i].r;
            
if  (j) i  =  j;  else   break ;
        }
        T[
++ No].v  =  _v; T[No].l  =  T[No].r  =   0 ; T[No].sz0  =  T[No].sz  =  T[No].mul  =   1 if  (_v  <  v0) slc(i, No);  else  src(i, No);
        
while  (i) {T[i].sz ++ ; maintain(i, _v  >  T[i].v); i  =  T[i].p;}
    }
}
void  del( int  x)
{
    
if  (T[x].mul  >   1 ) {
        T[x].mul
-- ;
        
while  (x) {T[x].sz0 -- ; x  =  T[x].p;}
    } 
else  {
        
int  l  =  T[x].l, r  =  T[x].r;
        
if  ( ! ||   ! r) {
            
if  (x  ==  rt0) T[rt0  =  l  +  r].p  =   0 else  {
                
int  p  =  T[x].p;  if  (x  ==  T[p].l) slc(p, l  +  r);  else  src(p, l  +  r);
                
while  (p) {T[p].sz0 -- ; T[p].sz -- ; p  =  T[p].p;}
            }
        } 
else  {
            
int  i  =  l, j;
            
while  (j  =  T[i].r) i  =  j;
            T[x].v 
=  T[i].v; T[x].mul  =  T[i].mul;  int  p  =  T[i].p;  if  (i  ==  T[p].l) slc(p, T[i].l);  else  src(p, T[i].l);
            
while  (p) {upd(p); p  =  T[p].p;}
        }
    }
}
int  Find_Kth( int  K)
{
    
int  i  =  rt0, s0, m0;
    
while  (i) {
        s0 
=  T[T[i].l].sz0; m0  =  T[i].mul;
        
if  (K  <=  s0) i  =  T[i].l;  else   if  (K  <=  s0  +  m0)  return  T[i].v;  else  {K  -=  s0  +  m0; i  =  T[i].r;}
    }
}
int  rank( int  _v)
{
    
int  i  =  rt0, tot  =   0 , v0;
    
while  (i) {
        v0 
=  T[i].v;
        
if  (_v  ==  v0) {tot  +=  T[T[i].l].sz0;  return  tot;}  else   if  (_v  <  v0) i  =  T[i].l;  else  {tot  +=  T[T[i].l].sz0  +  T[i].mul; i  =  T[i].r;}
    }
    
return  tot;
}
int  mkt( int  l,  int  r)
{
    T0[
++ No0].l  =  l; T0[No0].r  =  r;  int  mid  =  l  +  r  >>   1 ; T0[No0].mid  =  mid; rt0  =   0 ;
    re3(i, l, r) ins(a[i]); T0[No0].rt 
=  rt0;
    
if  (l  <  r) { int  No00  =  No0; T0[No00].lch  =  mkt(l, mid); T0[No00].rch  =  mkt(mid  +   1 , r);  return  No00;}  else  {T0[No0].lch  =  T0[No0].rch  =   0 return  No0;}
}
void  fs( int  x)
{
    
if  (x) {
        
int  l0  =  T0[x].l, r0  =  T0[x].r;
        
if  (l0  >=  l1  &&  r0  <=  r1) b[len ++ =  T0[x].rt;  else   if  (l0  >  r1  ||  r0  <  l1)  return else  {fs(T0[x].lch); fs(T0[x].rch);}
    }
}
void  C( int  x,  int  _v)
{
    
int  i  =  root, l0, r0, mid0, v0  =  a[x], N;
    
while  (i) {
        l0 
=  T0[i].l; r0  =  T0[i].r; mid0  =  T0[i].mid; rt0  =  T0[i].rt;
        N 
=  find(v0); del(N); ins(_v); T0[i].rt  =  rt0;
        
if  (x  <=  mid0) i  =  T0[i].lch;  else  i  =  T0[i].rch;
    }
    a[x] 
=  _v;
}
int  Q( int  K)
{
    len 
=   0 ; fs(root);
    
int  ls  =   1 , rs  =  n, mids, midv, tot;
    
while  (ls  <  rs) {
        mids 
=  ls  +  rs  +   1   >>   1 ; rt0  =  T0[root].rt; midv  =  Find_Kth(mids);
        tot 
=   1 ; re(i, len) {rt0  =  b[i]; tot  +=  rank(midv);}
        
if  (tot  <=  K) ls  =  mids;  else  rs  =  mids  -   1 ;
    }
    rt0 
=  T0[root].rt;  return  Find_Kth(ls);
}
int  main()
{
    
int  tests, m, x, y, K;
    
char  ch;
    scanf(
" %d " & tests);
    re(testno, tests) {
        scanf(
" %d%d " & n,  & m); No0  =  No  =   0 ;
        re(i, n) scanf(
" %d " & a[i]); ch  =  getchar();
        root 
=  mkt( 0 , n  -   1 );
        re(i, m) {
            ch 
=  getchar();
            
if  (ch  ==   ' C ' ) {
                scanf(
" %d%d%*c " & x,  & y);
                C(
-- x, y);
            } 
else  {
                scanf(
" %d%d%d%*c " & l1,  & r1,  & K);
                l1
-- ; r1 -- ; printf( " %d\n " , Q(K));
            }
        }
    }
    
return   0 ;
}

你可能感兴趣的:(线段树套平衡树 可修改的区间第K小问题)