用二叉树来理解树状数组

 

树状数组(Fenwick tree,又名binary indexed tree),是一种很实用的数据结构。它通过用节点i,记录数组下标在[ i –2^k + 1, i]这段区间的所有数的信息(其中,ki的二进制表示中末尾0的个数,设lowbit(i) = 2^k),实现在O(lg n) 时间内对数组数据的查找和更新。

树状数组的传统解释图,不能很直观的看出其所能进行的更新和查询操作。其最主要的操作函数lowbit(k)与数的二进制表示相关,本质上仍是一种二分。因而可以通过二叉树,对其进行分析。事实上,从二叉树图,我们对它所能进行的操作和不能进行的操作一目了然。

和前面提到的点树类似,先画一棵二叉树,然后对节点中序遍历(点树是采用广度优先),每个节点仍然只记录左子树信息,见图:

 

用二叉树来理解树状数组 

 

由于采用的是中序遍历,从节点1到节点k时,刚好有k个叶子被统计。

可以证明:

  叶子k,一定在节点k子树下。

  以节点k为根的树,其子树共有叶子lowbit(k)

节点k的父节点是:k + lowbit(k) k - lowbit(k) 

节点k + lowbit(k) 是节点k的最近父节点,且节点k在它的子树下。

节点k - lowbit(k) 是节点k的最近父节点,且节点k在它的子树下。

节点k,统计的叶子范围为:(k - lowbit(k),  k]

节点k的左孩子是:k - lowbit(k) / 2

 

下面分析树状数组两面主要应用:

1 更新数据x,进行区间查询。

2 更新区间,查询某个数。

由于,树状数组只统计了左子树的信息,因而只能查询更新区间[1, x]。只在在满足[x,y]的信息可以由[1,x-1][1,y]的信息推导出时,才能进行区间[x,y]的查询更新。这也是树状数组不能用于任意区间求最值的根本原因。

 

先定义两个集合:

up_right(k) 节点k所有的父节点,且节点k在它们的子树下。

up_left(k)   节点k所有的父节点,且节点k在它们的子树下。

 

1  更新数据x,查询区间[1,y]

显然,更新叶子x,要找出叶子x在哪些节点的子树下。因而节点k、所有的up_right(k)

都要更新。

查询[1, y],实际上就是把该区间拆分成一系列小区间,并找出统计这些区间的节点。可以通过找出y在哪些节点的子树下,这些节点恰好不重复的统计了区间[1, y-1]。因而要访问节点y、所有的up_left(y)

 

2 更新区间[1,y],查询数据x

  这和前面的操作恰好相反。与前面的最大不同之处在于:节点保存的不再是其叶子总个数这些信息,而是该区间的所有叶子都改变了多少。也就是说:每个叶子的信息,分散到了所有对它统计的节点上。因此操作和前面相似:

  更新[1,y]时,更新节点y、所有up_left(y)

  查询x时,  访问x、所有up_right(x)

 

前面的树状数组,只对左子树信息进行统计,如果从后往前读数据初始化树状数组,则变成只对右子树信息进行统计,这时更新和查询操作,刚好和前面的相反。

 

一般情况下,树状数组比点树省空间,对区间[1, M]只要M+1空间,查询更新时定位节点比较快,定位父节点和左右孩子相对麻烦点(不过,一般也不用到。从上往下查找,可参考下面代码中的erease_nth函数(删除第n小的数))。

 

下面是使用树状数组的实现代码(求逆序数和模拟约瑟夫环问题):

 

 

树状数组
// www.cnblogs.com/flyinghearts
#include < cstdio >  
#include
< cstring >  
#include
< cassert >  
 
template
< int  N >   struct  Round2k 
enum  { down  =  Round2k < /   2u > ::down  *   2 }; };

template
<>   struct  Round2k < 1 >  {  enum  { down  =   1 }; };
 

template 
< int   Total, typename T  =   int >    // 区间[1, Total]
class   BIT {
  
enum  { Min2k  =  Round2k < Total > ::down};  
  T info[Total 
+   1 ];                
  T sz;                                 
// 可以用info[0]储存总大小
  
public :
  BIT() { clear(); }
  
void  clear() { memset( this 0 sizeof ( * this ));}
  
int  size() {  return  sz; }

  
int  lowbit( int  idx) {  return  idx  &   - idx;}
  
// 寻找最近的父节点,left_up/right_up 分别使得idx在其右/左子树下
   void  left_up( int &   idx) { idx  -=  lowbit(idx); }
  
void  right_up( int &   idx) { idx  +=  lowbit(idx); }

  
void  update( int  idx , const   int  val  =   1 ) {    // 叶子idx 改变val个  
    assert(idx  >   0 );
    sz 
+=  val;
    
for  (; idx  <=  Total; right_up(idx)) info[idx]  +=  val; 
  }

  
void  init( int  arr[],  int  n) {                //  arr[i]为叶子i+1的个数
    assert(n  <=  Total);
    sz 
=  n;
    
//  for (int i = 0; i < n; ) {
      
//  info[i + 1] = arr[i];
      
//  if (++i >= n) break;
      
//  info[i + 1] = arr[i];
      
//  ++i;
      
//  for (int j = 1; j < lowbit(i); j *= 2u) info[i] += info[i - j];
    
//  }  
     for  ( int  i  =   0 ; i  <  n; ) {
      info[i 
+   1 =  arr[i];
      
if  ( ++ >=  n)  break ;
      
int  sum  =  arr[i];
      
int  pr  =   ++ i;
      left_up(pr);
      
for  ( int  j  =  i  -   1 ; j  >  pr; left_up(j)) sum  +=  info[j];
      info[i] 
=  sum;  
    }
  }
  
  
int  count( int  idx) {   // [1,idx] - [1, idx-1]
    assert(idx  >   0 );      
    
int  sum  =  info[idx];
    
//  int pr = idx;    // int pr = idx - lowbit(idx);    
    
//  left_up(pr);   
    
//  for (--idx; idx > pr; left_up(idx)) sum -= info[idx];  //
    
//  return sum;
     for  ( int  j  =   1 ; j  <  lowbit(idx); j  *=   2u ) sum  -=  info[idx  -  j];
    
return  sum;
  }  
  
  
int  lteq( int  idx) {                                   // 小等于
    assert(idx  >=   1   &&  idx  <=  Total);
      
int  sum  =   0 ;
    
for  (; idx  >   0 ; left_up(idx)) sum  +=  info[idx];
      
return  sum;
  }
  
  
int  gt( int  idx) {  return  sz  -  lteq(idx); }            // 大于

  
int   operator []( int  n)  {  return  erase_nth(n,  0 ); }   // 第n小
  
  
int  erase_nth( int  n,  const   bool  erase_flag  =   true )    // 删除第n小的数
  {
    assert(n 
>= 1   &&  n  <=  sz);
    sz 
-=  erase_flag;
    
int  idx  =  Min2k;                                // 从上往下搜索,先定位根节点 
     for  ( int  k  =  idx  /   2u ; k  >   0 ; k  /=   2u ) {
      
int  t  =  info[idx];
      
if  (n  <=  info[idx]) { info[idx]  -=  erase_flag; idx  -=  k;}   // 进入左子树      
       else  {
        n 
-=  t;
        
if  (Total  !=  Min2k  &&  Total  !=  Min2k  -   1 // 若不是完全二叉树
           while  (idx  +  k  >  Total)  k  /=   2u ;        // 则必须计算右孩子的编号 
        idx  +=  k;                                   // 进入右子树   
      }
    }
    assert(idx 
%   2u );                    // 最底层节点m一定是奇数,有两个叶子m,m+1
     if  (n  >  info[idx])  return  idx  +   1 ;   // 节点m+1前面已经更新过
    info[idx]  -=  erase_flag; 
    
return  idx;
  }

  
void  show()
  {
    
for  ( int  i  =   1 ; i  <=  Total;  ++ i)
      
if  (count(i)) printf( " %2d  " , i);
    printf(
" \n " );  
  }
  
}; 



void  ring()            // 约瑟夫环
{
  
const   int  N  =   17 ;    // N个人编号:1,2, ... N
   const   int  M  =   7 ;     // 报数:1到M,报到M的出列
  printf( "  N: %d   M: %d\n " , N, M);
  BIT
< N >  pt;
  
//  for (int i = 0; i < N; ++i) pt.update(i + 1);
   int  arr[N];
  
for  ( int  i  =   0 ; i  <  N;  ++ i) arr[i]  =   1 ;
  pt.init(arr, N);

  
for  ( int  j  =  N, k  =   0 ; j  >=   1 -- j) {
    k 
=  (k  +  M - 1 %  j;
    
int  t  =  pt.erase_nth(k  +   1 );
    printf(
"  turn: %2d  out: %2d   rest:   " , N  -  j, t);
    pt.show();
  }
  printf(
"  \n\n " );
}

int  ra( int  arr[],  int  len)  // 求逆序数-直接搜索
{
  
int  sum  =   0 ;
  
for  ( int  i  =   0 ; i  <  len  -   1 ++ i)
    
for  ( int  j  =  i  +   1 ; j  <  len;  ++ j)
      
if  (arr[i]  >  arr[j])  ++ sum;
  
return  sum;    
}

template
< int  N >
int  rb( int  arr[],  int  len)  // 求逆序数-使用树状数组
{
  BIT
< N >  pt;
  
int  sum  =   0 ;
  
for  ( int  i  =   0 ; i  <  len;  ++ i) {
    pt.update(arr[i] 
+   1 );
    sum 
+=  pt.gt(arr[i]  +   1 );
  }
  
return  sum;  
}


int  main()
{
  
int  arr[]  =  {  4 , 3 , 2 , 1 , 0 , 5 1 , 3 , 0 , 2 };
  
const   int  N  =   sizeof (arr)  /   sizeof (arr[ 0 ]);
  printf(
" %d %d\n\n " , ra(arr, N), rb < 6 > (arr, N));
  ring();
}

 

 

你可能感兴趣的:(树状数组)