树状数组的原理及应用 (BIT)

树状数组的原理及应用 (BIT)

1.原理图

树状数组的原理及应用 (BIT)_第1张图片

2.BIT相关函数

  • lowbit 运算

    • 含义:取x的二进制最右边的1和它右边所有的0,即2^k,其中k表示1之后0的个数。
    • 代码:

      int lorbit(int x)
      {
          return x & -x;
      }
  • C[x]的含义:

    • C[x] = (x - lowbit(x), x] = [x-lowbit(x) + 1, x]
    • C[x]的覆盖长度就是lowbit(x),注意树状数组的下标必须从1开始。
  • query 函数

    • 代码:

      int query(int x)
      {
      int sum = 0;
      for (int i = x; i > 0; i -= lowbit(i))
          sum += tr[i];
      return sum;
      }
    • 时间复杂度:O(logN)
    • 区间[x, y]内树的和,即A[x] + A[x -1] +...+ A[y],可以使用query(y) - query(x - 1)表示
  • add 函数

    • 代码:

      void add(int x, int v)
      {
          for (int i = x; i < N; i += lowbit(i))
              tr[i] += v;
      }
    • 时间复杂度:O(logN)
    • 这个过程是从右向左不断定位x的二进制最右边1左边0的过程。

3.BIT应用

  • 问题1:给定一个有N个正整数的A( N <= 100000,A[i] <= 100000),对序列中的每个数,求出序列中它左边比它小的树的个数。

    • 代码:

      #include  
      #include 
      #include 
      #include 
      
      using namespace std;
      
      const int N = 100010;
      
      int a[N];
      int tr[N];
      
      int lowbit(int x)
      {
          return x & -x;
      }
      
      void add(int x, int v)
      {
          for(int i = x; i < N; i += lowbit(i))
              tr[i] += v;
      }
      
      int query(int x)
      {
          int res = 0;
          for (int i = x; i > 0; i -= lowbit(i))
              res  += tr[i];
          return res;
      }
      
      int main()
      {
          int n;
          scanf("%d", &n);
          memset(tr, 0, sizeof tr);
          for (int i = 1; i <= n; i++)    
          {
              scanf("%d", &a[i]);
              add(a[i], 1);
          printf("%d\n", query(a[i] - 1));
          }
          return 0;
      }
  • 问题2:统计序列中在元素左边比该元素大的元素的个数。

    • 解法:

      query(N) - query(A[i]);
  • 问题3:如果A[i] > N,则需要对数组进行离散化操作,将任何不在合适区间的整数或者非整数都转换为不超过元素个数的整数。下述是针对“统计序列中在元素左边比该元素小的元素的个数”的问题给出的代码。

    • 代码:

      #include 
      #include 
      #include 
      #include 
      
      using namespace std;
      
      const int N = 100010;
      
      struct Node
      {
          int val;
          int pos;
      } temp[N];
      
      int A[N];
      int tr[N];
      
      int lowbit(int x)
      {
          return x & -x;
      }
      
      void add(int x, int v)
      {
          for (int i = x; i < N; i += lowbit(i))
              tr[i] += v;
      }
      
      int query(int x)
      {
          int sum = 0;
          for (int i = x; i > 0; i -= lowbit(i))
              sum += tr[i];
          return sum;
      }
      
      bool cmp(Node a, Node b)
      {
          return a.val < b.val;
      }
      
      int main()
      {
          int n;
          scanf("%d", &n);
      
          for (int i = 0; i < n; i ++)
              {
                  scanf("%d", &temp[i].val);
                  temp[i].pos = i;
              }
      
          //离散化
          sort (temp , temp + n, cmp);
          for (int i = 0; i < n; i ++)
              {
                  if ( i == 0 || temp[i].val != temp[i - 1].val)
                      {
                          A[temp[i].pos] = i + 1;
                      }
                  else
                      {
                          A[temp[i].pos] = A[temp[i - 1].pos];
                      }
              }
      
          //进入更新和求和操作
          for (int i = n - 1; i >= 0; i --)
              {
                  add(A[i], 1);
      
                  printf("%d\n", query(A[i] - 1));
              }
          return 0;
      }
      

你可能感兴趣的:(算法,c++)