11、树状数组(binary_indexed_tree)

1、简介

所谓的Binary Indexed Tree,首先需要明确它其实并不是一棵树。Binary Indexed Tree事实上是将根据数字的二进制表示来对数组中的元素进行逻辑上的分层存储。

树状数组的核心思想:

  • 每个元素是原数组中一个或多个连续元素的和
  • 在进行连续求和操作 a[1]+...a[n] 时,只需求树状数组中某几个元素的和即可,时间复杂度为 O(logn)
  • 在进行修改某个元素 a[i] 时,只需要修改树状数组中某几个元素的和即可,时间复杂度为 O(logn)

下面就是一个树状数组的示意图:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4UxKD39g-1577762310267)(./pic/bit.png)]

解释如下:

(1)数组

  • a[]: 保存原始数据的数组。
  • e[]: 树状数组,其中的任意一个元素e[i]可能是一个或多个 a 数组中元素的和。如:e[2]=a[1]+a[2]; e[3]=a[3]; e[4]=a[1]+a[2]+a[3]+a[4]。

(2)e[i]的值

如果数字 i 的二进制表示中末尾有 k 个连续的 0,则 e[i] 是数组 a2^k 个元素的和。即 e[i]=a[i-2^k+1]+a[i-2^k+2]+...+a[i-1]+a[i]

示例:

4=100(2)  e[4]=a[1]+a[2]+a[3]+a[4];
6=110(2)  e[6]=a[5]+a[6]
7=111(2)  e[7]=a[7]

2^k 的计算方法:

  • 2^k = (i & (-i)); (利用机器补码特性)
  • 2^k = (i & (i^(i-1));

(3)后继、前驱

后继:可以理解为当前节点的父节点。是离他最近的,且编号末尾连续 0 比他多的就是他的父节点。

  • 如**e[4] = e[2]+e[3]+a[4]** = a[1]+a[2]+a[3]+a[4]e[2]e[3]的后继就是e[4]
  • 后继主要是用来计算e数组,将当前已经计算出的e[i]添加到他们后继中。

前驱:节点前驱的编号即为比自己小的,最近的,最末连续 0 比自己多的节点。

  • 如:Sum(7)=a[1]+...+a[7]=e[7]+e[6]+e[4]。(e[7]的前驱是e[6], e[6]的前驱是e[4])
  • 前驱主要是在计算连续和时,避免重复添加元素。

前驱后继的计算

  • lowbit(i) = ( (i-1) ^ i) & i(i & (-i))

  • 节点e[i]的前驱为 e[ i - lowbit(i) ]

  • 节点e[i]的前驱为 e[ i + lowbit(i) ]

2、复杂度分析

根据上面的分析,我们可以看出,对于长度为n的数组,单个updateprefixSum操作最多需要访问logn的元素,也就是说单个updateprefixSum操作的时间复杂度均为O(logn)

构建Binary Indexed Tree的时间复杂度为O(nlogn)或者O(n),取决于我们使用哪种算法。

Binary Indexed TreePython 实现:

# -*- coding: utf-8 -*-
"""
    Description: 树状数组
"""


class NumArray(object):

    def __init__(self, arr):
        self.arr = arr
        self.tree = [None] * (len(self.arr) + 1)
        self.build()

    def build(self):
        """构建树状数组"""
        for i in range(len(self.tree)):
            total = 0
            low_bit = i & ((i - 1) ^ i)
            j = i
            while j > i - low_bit:
                total = total + self.arr[j - 1]
                j -= 1
            self.tree[i] = total

    def update(self, i, value):
        """更新 i 位置的值为 value"""
        temp = value - self.arr[i]
        self.arr[i] = value
        i += 1
        while i < len(self.tree):
            self.tree[i] += temp
            i = i + (i & ((i - 1) ^ i))

    def sum_range(self, i, j):
        """求 [i, j] 区间内的和"""
        return self.get_sum(j) - self.get_sum(i)

    def get_sum(self, i):
        total = 0
        i += 1
        while i > 0:
            total += self.tree[i]
            i = i - ((i - 1) ^ i)
        return total

    def print(self):
        print(self.tree)


if __name__ == '__main__':
    array = [1, 7, 3, 0, 5, 8, 3, 2, 6, 2, 1, 1, 4, 5]

    na = NumArray(array)
    na.print()
# [0, 1, 8, 3, 11, 5, 13, 3, 29, 6, 8, 1, 10, 4, 9]
# array 的下标是从 0 开始的,tree 的下标从 1 开始,因此 tree 的长度是 array 的长度 + 1

你可能感兴趣的:(Python3,数据结构与算法)