树状数组模板

0X00 理解树状数组

没有学习过的同学可以看这个视频:树状数组

如果想非常顺利的写出这个模板,得记住下面这张图

树状数组模板_第1张图片

假设我们要求的是区间和。

其中绿色的部分,就是这个区间的和。假设我们要给 a[5] 添加一个值,那么要改变的区间有 t[5] t[6] t[8],而假设我们要求一个区间的和,比如 a[3] ~ a[7],按照前缀和的思想就是 sums[7] - sums[1]

sums[7] = t[4] + t[6] + t[7]

sums[1] = t[1]

所以树状数组要做的就是快速(logn)求这个 t 数组

0X01 模板

「树状数组」只有两个操作(数组下标从 1 开始):

  • 查询
  • 添加

一维前缀和

def lowbit(x):
    return x & -x
    
def add(i, x):
    while i <= n:
        tr[i] += x
        i += lowbit(i)
        
def ask(i):
    res = 0
    while i > 0:
        res += tr[i]
        i -= lowbit(i)
    return res

下标从 1 开始,只要能记住那张图,知道 add 的时候是 +, ask 的时候是 -。就很容易写出这个模板

二维树状数组

与一维树状数组基本一样:模板参考我下面的第二道题目

0X00 相关题目

683. K 个空花盆 注意树状数组的边界

class Solution:
    def kEmptySlots(self, bulbs: List[int], k: int) -> int:
        n = len(bulbs)
        tr = [0] * (n+1)

        def lowbit(x):
            return x & -x
        
        def add(i, x):
            while i <= n:
                tr[i] += x
                i += lowbit(i)
        
        def ask(i):
            res = 0
            if i > n: i = n
            while i > 0:
                res += tr[i]
                i -= lowbit(i)
            return res
        
        def query(l, r):
            n1 = ask(l-1)
            n2 = ask(r)
            return True if n2 - n1 >= 1 else False

        d = 0
        for f in bulbs:
            d += 1
            add(f, 1)
            if query(f+k+1, f+k+1) and not query(f+1, f+k):
                return d
            if query(f-k-1, f-k-1) and not query(f-k, f-1):
                return d

        return -1

308. 二维区域和检索 - 可变

def lowbit(x):
    return x & -x

class NumMatrix:

    def _add(self, x, y, c):
        m, n = self.m, self.n
        i = x 
        while i <= m:
            j = y
            while j <= n:
                self.tr[i][j] += c
                j += lowbit(j)
            i += lowbit(i)
    
    def _ask(self, x, y):
        res = 0
        i = x
        while i > 0:
            j = y
            while j > 0:
                res += self.tr[i][j]
                j -= lowbit(j)
            i -= lowbit(i)
        return res

    def __init__(self, matrix: List[List[int]]):
        if not len(matrix): return
        self.m, self.n = len(matrix), len(matrix[0])
        m, n = self.m, self.n
        self.tr = [[0] * (n+1) for _ in range(m+1)]
        self.ma = matrix
        for i in range(self.m):
            for j in range(self.n):
                self._add(i+1, j+1, matrix[i][j])

    def update(self, x: int, y: int, v: int) -> None:
        delta = v - self.ma[x][y]
        self.ma[x][y] = v
        self._add(x+1, y+1, delta)
        

    def sumRegion(self, x1: int, y1: int, x2: int, y2: int) -> int:
        return self._ask(x2+1, y2+1) + self._ask(x1, y1) - self._ask(x2+1, y1) - self._ask(x1, y2+1)
        


# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)

你可能感兴趣的:(树状数组模板)