[python刷题模板] 树状数组

[python刷题模板] 树状数组 BIT

    • 一、 算法&数据结构
      • 1. 描述
      • 2. 复杂度分析
      • 3. 常见应用
      • 4. 常用优化
    • 二、 模板代码
      • 1. 单点赋值(增加),区间求和(PURQ)
      • 2. 区间更新,单点询值(RUPQ)
      • 3.区间更新,区间求和(RURQ)
      • 5. 单点更新区间求极值
      • 6. 单点赋值,区间询问最大(LIS II)
      • 7. 二维树状数组(IUPQ)
    • 三、其他

一、 算法&数据结构

1. 描述

区间问题我通常用线段树,比较好理解;但树状数组的常数实在太小了,有的题可能会卡线段树(Python 写的线段树容易超时)。原理还没完全吃透,只好整理成模板直接用。
国内喜欢叫 BIT(Binary Indexed Tree),国外很多叫 Fenwick Tree。
  • 树状数组的核心是:i的父节点是i+lowbit(i)
  • (原文此处为一张树状数组结构示意图)

2. 复杂度分析

  1. 查询 query:O(log n)
  2. 更新 update:O(log n)

3. 常见应用

  1. 单点更新,区间求和
  2. 区间更新,单点求值
  3. 区间更新,区间求和
  4. 单点更新,区间求极值

4. 常用优化

  1. lowbit(x) = x & -x 的 *O(1)* 解法,需要理解补码知识;也可以用 i &= i - 1 直接去掉最低位的 1(见 sum_prefix)。

二、 模板代码

1. 单点赋值(增加),区间求和(PURQ)

例题: 307. 区域和检索 - 数组可修改
这是树状数组最常用、最擅长的场景;当然,线段树也能做。

class BinIndexTree:
    """Fenwick tree (BIT): point add / point set, prefix and interval sums.

    All public indices are 1-based.  Construct with an int (an all-zero tree
    of that size) or with a list of initial values.
    """

    def __init__(self, size_or_nums):
        # An int means "empty tree of this size"; anything else is the initial array.
        if isinstance(size_or_nums, int):
            self.size = size_or_nums
            self.c = [0] * (self.size + 5)
            # Raw values, kept so set_point can turn an assignment into a delta.
            self.a = [0] * (self.size + 5)
        else:
            self.size = len(size_or_nums)
            self.c = [0] * (self.size + 5)
            self.a = [0] * (self.size + 5)
            # O(n) build: every node first receives its own value (children,
            # which have smaller indices, have already pushed into it), then
            # pushes its partial sum to its parent i + lowbit(i) exactly once.
            # This yields the same c[] as n calls to add_point.
            for i, v in enumerate(size_or_nums, 1):
                self.a[i] = v
                self.c[i] += v
                parent = i + (i & -i)
                if parent <= self.size:
                    self.c[parent] += self.c[i]

    def add_point(self, i, v):
        """Add v to position i (1-based).

        Raises IndexError for i < 1; with i == 0 the update loop would never
        terminate because lowbit(0) == 0.
        """
        if i < 1:
            raise IndexError("BinIndexTree index must be >= 1")
        if i <= self.size:
            self.a[i] += v
        while i <= self.size:
            self.c[i] += v
            i += i & -i

    def set_point(self, i, v):
        """Assign a[i] = v (1-based) by adding the difference to the current value."""
        self.add_point(i, v - self.a[i])

    def sum_interval(self, l, r):
        """Return a[l] + ... + a[r] over the closed interval [l, r] (1-based)."""
        return self.sum_prefix(r) - self.sum_prefix(l - 1)

    def sum_prefix(self, i):
        """Return a[1] + ... + a[i]; i <= 0 gives 0."""
        s = 0
        while i >= 1:
            s += self.c[i]
            i &= i - 1  # drop the lowest set bit, same as i -= i & -i
        return s

    def min_right(self, i):
        """Return the first index j in [i, size] holding a positive value (1-based).

        Returns size + 1 if there is none.  The binary search relies on prefix
        sums being non-decreasing, i.e. all stored values must be >= 0.
        """
        p = self.sum_prefix(i)
        # sum_prefix(0) == 0, so this single test also covers i == 1.
        if p > self.sum_prefix(i - 1):
            return i
        lo, hi = i, self.size + 1
        # Invariant: sum_prefix(lo) == p and (hi == size + 1 or sum_prefix(hi) > p).
        while lo + 1 < hi:
            mid = (lo + hi) >> 1
            if self.sum_prefix(mid) > p:
                hi = mid
            else:
                lo = mid
        return hi

    def lowbit(self, x):
        """Return the lowest set bit of x."""
        return x & -x


class NumArray:
    """LeetCode 307: mutable array with range-sum queries, backed by a Fenwick tree."""

    def __init__(self, nums: List[int]):
        self.tree = BinIndexTree(nums)
        # Keep our own copy of the values so update() needs only add_point
        # (a point-assignment method is not guaranteed on the tree).
        self.nums = list(nums)

    def update(self, index: int, val: int) -> None:
        """Set nums[index] = val (0-based)."""
        self.tree.add_point(index + 1, val - self.nums[index])
        self.nums[index] = val

    def sumRange(self, left: int, right: int) -> int:
        """Return nums[left] + ... + nums[right] (0-based, inclusive)."""
        return self.tree.sum_interval(left + 1, right + 1)

2. 区间更新,单点询值(RUPQ)

例题: 1589. 所有排列中的最大和
这题其实应该用差分数组,可以省一层log。思想就是树状数组的IUPQ模型。

树状数组经典案例,要用差分数组理解:
这个实际上是用树状数组维护原数组的差分数组c[i]=a[i]-a[i-1]
求原数组的值a[i]实际上是差分数组的前缀sum(c[0]…c[i]),所以get a[i]可以用sum c[i]表示,
而原数组a区间[x,y]更新+v,产生的效果是:x位置比x-1位置大了v,y+1位置比y小了v;
对差分数组c来说,产生变化的就是c[x]增加了v,c[y+1]减小了v,因为c数组代表的是a中每个数比前一个数的差。

  • sum[i]代替get[i],单点求值
  • 两步add(l,v)和add(r+1,-v),区间更新
class BinIndexTreeUpdateInterval:
    """Fenwick tree for range add / point query (RUPQ), 1-based indices.

    Internally it stores the difference array d[i] = a[i] - a[i-1].  A point
    value a[i] is the prefix sum of d, and adding v to a[l..r] only touches
    d[l] (+v) and d[r+1] (-v).
    """

    def __init__(self, size_or_nums):
        # An int means "all zeros of this size"; otherwise it is the initial array.
        if isinstance(size_or_nums, int):
            self.size = size_or_nums
            self.c = [0] * (self.size + 5)
        else:
            self.size = len(size_or_nums)
            self.c = [0] * (self.size + 5)
            # O(n) build over the difference array.  This gives the same c[]
            # as calling add_interval(i, i, v) for every element.
            prev = 0
            for i, v in enumerate(size_or_nums, 1):
                self.c[i] += v - prev
                prev = v
                parent = i + (i & -i)
                if parent <= self.size:
                    self.c[parent] += self.c[i]

    def add_point(self, i, v):
        """Internal: add v to the difference array at i (1-based). Not a point update of a."""
        while i <= self.size:
            self.c[i] += v
            i += self.lowbit(i)

    def sum_prefix(self, i):
        """Internal: prefix sum of the difference array, which equals a[i]."""
        s = 0
        while i >= 1:
            s += self.c[i]
            i -= self.lowbit(i)
        return s

    def add_interval(self, l, r, v):
        """Add v to every a[j] with l <= j <= r (1-based, closed interval)."""
        self.add_point(l, v)
        # If r + 1 > size, add_point is a no-op, which is what we want.
        self.add_point(r + 1, -v)

    def query_point(self, i):
        """Return the current value a[i] (1-based)."""
        return self.sum_prefix(i)

    def lowbit(self, x):
        """Return the lowest set bit of x."""
        return x & -x

class Solution:
    def maxSumRangeQuery(self, nums: List[int], requests: List[List[int]]) -> int:
        """LeetCode 1589: maximum total of all request sums over any permutation.

        Count how often each position is covered by a request, then pair the
        largest counts with the largest numbers.  Sorts nums in place.
        """
        MOD = 10**9 + 7
        n = len(nums)
        nums.sort(reverse=True)
        coverage = BinIndexTreeUpdateInterval(n + 5)
        for lo, hi in requests:
            # Requests are 0-based; the tree is 1-based.
            coverage.add_interval(lo + 1, hi + 1, 1)
        hits = sorted((coverage.query_point(pos) for pos in range(1, n + 1)), reverse=True)
        total = 0
        for cnt, val in zip(hits, nums):
            if not cnt:
                # Counts are sorted descending, so every remaining count is zero too.
                break
            total = (total + cnt * val % MOD) % MOD
        return total

3.区间更新,区间求和(RURQ)

题目: P3372 【模板】线段树 1

  • 记sigma(r,i)表示r数组的前i项和。
  • 我们知道,在“区间更新”的BIT(上一节的 RUPQ)中,实际维护的是原数组a的差分数组d。
  • 于是有a[n] = d[1]+d[2]+…+d[n]
  • 观察式子:
    a[1]+a[2]+…+a[n]
    = (d[1]) + (d[1]+d[2]) + … + (d[1]+d[2]+…+d[n])
    = n * d[1] + (n-1) * d[2] +… +d[n]
    = n * (d[1]+d[2]+…+d[n]) - (0 * d[1]+1 * d[2]+…+(n-1) * d[n]) (式子①)
    再维护一个数组d2[n],其中d2[i] = (i-1)*d[i]
    每当修改d的时候,就同步修改一下d2,这样复杂度就不会改变(代码中 c 对应 d,c2 对应 d2)

那么 sigma(a,n) = 式子①=n*sigma(d,n) - sigma(d2,n)

import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *

# Local debugging hook: 50924784 == 0x03090CF0, i.e. sys.hexversion of one
# specific interpreter (CPython 3.9.12 final).  Only there is input read from
# a local file instead of the real stdin.
if sys.hexversion == 50924784:
    sys.stdin = open('cfinput.txt')


def RI():
    """Read one stdin line and return its whitespace-separated ints (lazy map)."""
    return map(int, sys.stdin.buffer.readline().split())


def RS():
    """Read one stdin line and return its whitespace-separated tokens decoded to str (lazy map)."""
    return map(bytes.decode, sys.stdin.buffer.readline().strip().split())


def RILST():
    """Read one stdin line as a list of ints."""
    return list(RI())


MOD = 10 ** 9 + 7
"""https://www.luogu.com.cn/problem/P3372
模板题,区间加,区间求和。
"""


class BinIndexTreeRURQ:
    """Fenwick tree for range add / range sum (RURQ), 1-based indices.

    Two trees are kept over the difference array d of the underlying array a:
      c  stores d[i]
      c2 stores (i - 1) * d[i]
    so that a[1] + ... + a[n] = n * prefix(c, n) - prefix(c2, n).
    """

    def __init__(self, size_or_nums):
        # An int means "all zeros of this size"; otherwise it is the initial array.
        if isinstance(size_or_nums, int):
            self.size = size_or_nums
            self.c = [0] * (self.size + 5)
            self.c2 = [0] * (self.size + 5)
        else:
            self.size = len(size_or_nums)
            c = self.c = [0] * (self.size + 5)
            c2 = self.c2 = [0] * (self.size + 5)
            # O(n) build of both trees from d[i] = a[i] - a[i-1].  This gives
            # the same arrays as calling add_interval(i, i, v) per element.
            prev = 0
            for i, v in enumerate(size_or_nums, 1):
                d = v - prev
                prev = v
                c[i] += d
                c2[i] += (i - 1) * d
                parent = i + (i & -i)
                if parent <= self.size:
                    c[parent] += c[i]
                    c2[parent] += c2[i]

    def add_point(self, c, i, v):
        """Internal: add v at index i of the given tree c (self.c or self.c2 only)."""
        while i <= self.size:
            c[i] += v
            i += -i & i

    def sum_prefix(self, c, i):
        """Internal: prefix sum c[1..i] of the given tree (self.c or self.c2 only)."""
        s = 0
        while i >= 1:
            s += c[i]
            i -= -i & i
        return s

    def add_interval(self, l, r, v):
        """Add v to every a[j] with l <= j <= r (1-based, closed interval)."""
        self.add_point(self.c, l, v)
        self.add_point(self.c, r + 1, -v)
        # d2[i] = (i - 1) * d[i]: d[l] gains v and d[r+1] loses v.
        self.add_point(self.c2, l, (l - 1) * v)
        self.add_point(self.c2, r + 1, -v * r)

    def _sum_a_prefix(self, i):
        """Return a[1] + ... + a[i] via i * sigma(d, i) - sigma(d2, i)."""
        return self.sum_prefix(self.c, i) * i - self.sum_prefix(self.c2, i)

    def sum_interval(self, l, r):
        """Return a[l] + ... + a[r] over the closed interval [l, r] (1-based)."""
        return self._sum_a_prefix(r) - self._sum_a_prefix(l - 1)

    def query_point(self, i):
        """Return the current value a[i] (1-based)."""
        return self.sum_prefix(self.c, i)

    def lowbit(self, x):
        """Return the lowest set bit of x."""
        return x & -x


#  	 ms
def solve(n, m, a, qs):
    """Run Luogu P3372 operations and print every sum query result, one per line.

    Each op is [1, l, r, x] (add x to a[l..r]) or [2, l, r] (sum a[l..r]);
    indices are 1-based.  n and m are not used by the body.
    """
    bit = BinIndexTreeRURQ(a)
    sums = []
    for op in qs:
        kind = op[0]
        if kind == 1:
            bit.add_interval(op[1], op[2], op[3])
        elif kind == 2:
            sums.append(bit.sum_interval(op[1], op[2]))
    print('\n'.join(map(str, sums)))


if __name__ == '__main__':
    # Input: "n m", then the n initial values, then m operation lines.
    n, m = RI()
    a = RILST()
    q = [RILST() for _ in range(m)]
    solve(n, m, a, q)

5. 单点更新区间求极值

例题: CF522 D. Closest Equals
这是20220923的茶。

  • 相同元素组成可以看成线段,问题转化为求区间内最短线段。
  • 询问离线,按r排序,计算每个线段长度,记在左端点上。
  • 查询区间最小值即可。
  • 正常用线段树,但是py线段树过不了,于是上网查了个树状数组的模板
import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os

# Local debugging hook: 50924784 == 0x03090CF0, i.e. sys.hexversion of one
# specific interpreter (CPython 3.9.12 final).  Only there is input read from
# a local file instead of the real stdin.
if sys.hexversion == 50924784:
    sys.stdin = open('cfinput.txt')


def RI():
    """Read one line from the binary stdin buffer and return its ints (lazy map)."""
    return map(int, sys.stdin.buffer.readline().split())


def RS():
    """Read one line from the text-mode stdin and return its tokens as a list of str."""
    return sys.stdin.readline().strip().split()


def RILST():
    """Read one stdin line as a list of ints."""
    return list(RI())


MOD = 10 ** 9 + 7
"""https://codeforces.com/problemset/problem/522/D

输入 n(2≤n≤2e5) 和 m(1≤m≤2e5);然后输入一个长为 n 的数组 a(-1e9≤a[i]≤1e9),数组下标从 1 开始;然后输入 m 个询问,每个询问表示一个数组 a 内的闭区间 [L,R] (1≤L≤R≤n)。

对每个询问,输出区间内的相同元素下标之间的最小差值。如果区间内不存在相同元素,输出 -1。

输入
5 3
1 1 2 3 2
1 5
2 4
3 5
输出
1
-1
2

输入
6 5
1 2 1 3 2 3
4 6
1 3
2 5
2 4
1 6
输出
2
2
3
-1
2
"""


class ITreeNode:
    """Node of the lazily allocated segment tree used by IntervalTree."""
    __slots__ = ['val', 'l', 'r']

    def __init__(self, l=None, r=None, v=0):
        self.val, self.l, self.r = v, l, r


class IntervalTree:
    """Dynamically allocated segment tree for point assignment and range minimum.

    Callers pass a node and the index range [l, r] it covers explicitly, and
    should store the node returned by update_point.  Positions never assigned
    count as inf.
    """

    def __init__(self):
        # The root starts out covering only empty positions, so its minimum is
        # inf.  With the default v=0, a full-range query before any update
        # would wrongly return 0.
        self.root = ITreeNode(v=inf)

    def update_from_son(self, node):
        """Recompute node.val as the min over its existing children; return node."""
        node.val = min(node.l.val if node.l else inf, node.r.val if node.r else inf)
        return node

    def update_point(self, node, l, r, index, val):
        """Set position index to val inside the subtree node covering [l, r].

        Returns the (possibly newly created) subtree root.  An index outside
        [l, r] leaves the subtree unchanged; the original returned None here,
        which would drop the subtree if the caller stored the result.
        """
        if index < l or r < index:
            return node
        if not node:
            node = ITreeNode(v=inf)
        if l == r:
            node.val = val
            return node
        mid = (l + r) // 2
        if index <= mid:
            node.l = self.update_point(node.l, l, mid, index, val)
        else:
            node.r = self.update_point(node.r, mid + 1, r, index, val)
        return self.update_from_son(node)

    def query(self, node, l, r, x, y):
        """Return min over positions [x, y] within the subtree node covering [l, r]; inf if none."""
        if not node or y < l or r < x:
            return inf
        if x <= l and r <= y:
            return node.val
        mid = (l + r) // 2
        s = inf
        if x <= mid:
            s = min(s, self.query(node.l, l, mid, x, y))
        if mid < y:
            s = min(s, self.query(node.r, mid + 1, r, x, y))
        return s


class BinIndexTreeMax:
    """Fenwick tree for point "raise" (chmax) updates and range-maximum queries.

    a[x] is the value at x and h[x] is the max of a over (x - lowbit(x), x].
    h entries are only ever raised, so a point can only increase:
    update(x, v) sets a[x] = max(a[x], v).  Indices are 1-based.
    """

    def __init__(self, size):
        self.size = size
        self.a = [-inf for _ in range(size + 5)]
        self.h = self.a[:]
        # Largest value ever stored; lets query stop as soon as it is reached.
        self.mx = -inf

    def update(self, x, v):
        """Raise a[x] to v; a no-op when a[x] >= v already.

        Lowering a[x] is ignored because h[] cannot be lowered, and assigning
        it anyway would make a[] and h[] disagree.
        """
        a = self.a
        if v <= a[x]:
            return
        a[x] = v
        if v > self.mx:
            self.mx = v
        h = self.h
        while x <= self.size:
            if h[x] < v:
                h[x] = v
            else:
                # Every ancestor covers x's range, so it already holds >= h[x] >= v.
                break
            x += self.lowbit(x)

    def query(self, l, r):
        """Return max(a[l..r]) (1-based, inclusive); -inf for an empty range or no data."""
        if l > r:
            # With l > r the loop below would never reach l and never terminate.
            return -inf
        a = self.a
        h = self.h
        ans = a[r]
        while l != r:
            r -= 1
            # Use whole h[] blocks while they lie strictly inside [l, r].
            while r - self.lowbit(r) > l:
                if ans < h[r]:
                    ans = h[r]
                    if ans == self.mx:
                        break
                r -= self.lowbit(r)
            if ans < a[r]:
                ans = a[r]
            if ans == self.mx:
                # Nothing stored anywhere can exceed the global max.
                break
        return ans

    def lowbit(self, x):
        """Return the lowest set bit of x."""
        return x & -x


class BinIndexTreeMin:
    """Fenwick tree for point "lower" (chmin) updates and range-minimum queries.

    a[x] is the value at x and h[x] is the min of a over (x - lowbit(x), x].
    h entries are only ever lowered, so a point can only decrease:
    update(x, v) sets a[x] = min(a[x], v).  Indices are 1-based.
    """

    def __init__(self, size):
        self.size = size
        self.a = [inf for _ in range(size + 5)]
        self.h = self.a[:]
        # Smallest value ever stored; lets query stop as soon as it is reached.
        self.mn = inf

    def update(self, x, v):
        """Lower a[x] to v; a no-op when a[x] <= v already.

        Raising a[x] is ignored because h[] cannot be raised, and assigning
        it anyway would make a[] and h[] disagree.
        """
        a = self.a
        if v >= a[x]:
            return
        a[x] = v
        if v < self.mn:
            self.mn = v
        h = self.h
        while x <= self.size:
            if h[x] > v:
                h[x] = v
            else:
                # Every ancestor covers x's range, so it already holds <= h[x] <= v.
                break
            x += self.lowbit(x)

    def query(self, l, r):
        """Return min(a[l..r]) (1-based, inclusive); inf for an empty range or no data."""
        if l > r:
            # With l > r the loop below would never reach l and never terminate.
            return inf
        a = self.a
        h = self.h
        ans = a[r]
        while l != r:
            r -= 1
            # Use whole h[] blocks while they lie strictly inside [l, r].
            while r - self.lowbit(r) > l:
                if ans > h[r]:
                    ans = h[r]
                    if ans == self.mn:
                        break
                r -= self.lowbit(r)
            if ans > a[r]:
                ans = a[r]
            if ans == self.mn:
                # Nothing stored anywhere can be below the global min.
                break
        return ans

    def lowbit(self, x):
        """Return the lowest set bit of x."""
        return x & -x


# 	2932  ms
def solve(n, m, a, q):
    """Answer CF 522D offline and print one answer per query.

    Queries are processed in order of right endpoint.  Whenever value v
    reappears at j, the segment [prev, j] is recorded as length j - prev at
    its left endpoint prev.  The answer to [l, r] is then the minimum stored
    on left endpoints in [l, r], or -1 if there is none.
    """
    arr = [0, *a]  # shift to 1-based positions
    order = sorted(((lo, hi, k) for k, (lo, hi) in enumerate(q)), key=lambda t: t[1])
    bit = BinIndexTreeMin(n + 5)
    last_seen = {}
    res = [-1] * m
    nxt = 1  # first position not yet scanned
    for lo, hi, k in order:
        for pos in range(nxt, hi + 1):
            val = arr[pos]
            prev = last_seen.get(val)
            if prev is not None:
                bit.update(prev, pos - prev)
            last_seen[val] = pos
        nxt = max(nxt, hi + 1)
        best = bit.query(lo, hi)
        if best < inf:
            res[k] = best
    print('\n'.join(map(str, res)))


if __name__ == '__main__':
    # Input: "n m", then the array, then m lines "L R".
    n, m = RI()
    a = RILST()
    q = [RILST() for _ in range(m)]
    solve(n, m, a, q)

6. 单点赋值,区间询问最大(LIS II)

例题: 2407. 最长递增子序列 II
周赛T4,当时用线段树做的;实际测试线段树的表现甚至优于树状数组,奇怪极了。


class BinIndexTreeMax:
    """Fenwick tree for point "raise" (chmax) updates and range-maximum queries.

    a[x] is the value at x and h[x] is the max of a over (x - lowbit(x), x].
    Indices are 1-based.  Construct with an int (all -inf) or a list of
    initial values.
    """

    def __init__(self, size_or_nums):
        # An int means "empty tree of this size"; otherwise it is the initial array.
        if isinstance(size_or_nums, int):
            self.size = size_or_nums
            self.h = [-inf] * (self.size + 5)
            self.a = [-inf] * (self.size + 5)
        else:
            self.size = len(size_or_nums)
            self.a = [-inf] * (self.size + 5)
            self.h = [-inf] * (self.size + 5)
            for i, v in enumerate(size_or_nums, 1):
                self.set_point(i, v)

    def set_point(self, x, v):
        """Raise position x (1-based) to v, so that a[x] = max(a[x], v).

        h[] entries can never be lowered.  Lowering a[x] alone would make
        single-point queries disagree with range queries, so a smaller v is
        ignored.
        """
        if v <= self.a[x]:
            return
        self.a[x] = v
        while x <= self.size:
            if self.h[x] < v:
                self.h[x] = v
            x += x & -x

    def query_interval_max(self, l, r):
        """Return max(a[l..r]) (1-based, inclusive); -inf for an empty range (l > r) or no data."""
        ans = -inf
        while l <= r:
            if ans < self.a[r]:
                ans = self.a[r]
            r -= 1
            # Jump over whole h[] blocks (r - lowbit(r), r] that lie inside [l, r].
            while r - (r & -r) >= l:
                if ans < self.h[r]:
                    ans = self.h[r]
                r -= r & -r
        return ans

    def lowbit(self, x):
        """Return the lowest set bit of x."""
        return x & -x

    
class Solution:
    def lengthOfLIS(self, nums: List[int], k: int) -> int:
        """LeetCode 2407: longest strictly increasing subsequence whose adjacent gap is <= k.

        The tree is indexed by value: position v holds the longest valid
        subsequence ending in v.  Each v extends the best one ending in
        [v - k, v - 1].
        """
        mx = max(nums)
        tree = BinIndexTreeMax(mx)
        for v in nums:
            # -inf (empty range or no predecessor yet) means v starts a new subsequence of length 1.
            best_prev = tree.query_interval_max(max(1, v - k), v - 1)
            tree.set_point(v, max(best_prev + 1, 1))
        return tree.query_interval_max(1, mx)

7. 二维树状数组(IUPQ)

例题: 6292. 子矩阵元素加 1

  • 周赛T2,这题会卡一维树状数组;但可以用差分通过,也可以用二维树状数组或二维差分通过。
# 2-D Fenwick tree: rectangle add, point query (internally a 2-D difference array).
class BinTree2DIUPQ:
    """2-D Fenwick tree over an m x n grid supporting rectangle add and point query.

    The tree stores the 2-D difference array.  A cell's value is the 2-D
    prefix sum of that difference array.  Indices are 1-based.
    """

    def __init__(self, m, n):
        self.n = n
        self.m = m
        self.tree = [[0] * (n + 1) for _ in range(m + 1)]

    def lowbit(self, x):
        """Return the lowest set bit of x."""
        return x & (-x)

    def _update_point(self, x, y, val):
        """Add val to difference cell (x, y) and every Fenwick ancestor; out-of-grid is a no-op."""
        rows, cols, grid = self.m, self.n, self.tree
        i = x
        while i <= rows:
            row = grid[i]
            j = y
            while j <= cols:
                row[j] += val
                j += j & -j
            i += i & -i

    def _sum_prefix(self, x, y):
        """Return the sum of difference cells over [1..x] x [1..y]."""
        grid = self.tree
        total = 0
        i = x
        while i > 0:
            row = grid[i]
            j = y
            while j > 0:
                total += row[j]
                j -= j & -j
            i -= i & -i
        return total

    def add_interval(self, x1, y1, x2, y2, v):
        """Add v to every cell (i, j) with x1 <= i <= x2 and y1 <= j <= y2."""
        # Standard 2-D difference: +v at the top-left corner, -v just past each
        # far edge, and +v back at the far corner.
        for px, py, delta in ((x1, y1, v), (x2 + 1, y1, -v), (x1, y2 + 1, -v), (x2 + 1, y2 + 1, v)):
            self._update_point(px, py, delta)

    def query_point(self, x, y):
        """Return the current value of cell (x, y)."""
        return self._sum_prefix(x, y)
        


class Solution:
    def rangeAddQueries(self, n: int, queries: List[List[int]]) -> List[List[int]]:
        """LeetCode 6292: add 1 to each queried sub-matrix and return the final n x n grid."""
        bit = BinTree2DIUPQ(n, n)
        for r1, c1, r2, c2 in queries:
            # Queries use 0-based corners; the tree is 1-based.
            bit.add_interval(r1 + 1, c1 + 1, r2 + 1, c2 + 1, 1)
        return [[bit.query_point(row, col) for col in range(1, n + 1)] for row in range(1, n + 1)]

三、其他

  1. 树状数组还有很多更巧妙的用法,我暂时还没掌握,可以参考这篇博文:
    https://blog.csdn.net/WhereIsHeroFrom/article/details/113598114

你可能感兴趣的:(python刷题模板,python,leetcode,算法)