区间问题通常我会用线段树,比较好理解;但树状数组常数实在太低了,有的题估计卡你,学不明白,只好打板
国内喜欢叫BIT(Binary-Indexed-Tree), 国外很多叫FrenwickTree
例题: 307. 区域和检索 - 数组可修改
树状数组最常用最擅长部分,当然,线段树能做
class BinIndexTree:
def __init__(self, size_or_nums): # 树状数组,下标需要从1开始
# 如果size 是数字,那就设置size和空数据;如果size是数组,那就是a
if isinstance(size_or_nums, int):
self.size = size_or_nums
self.c = [0 for _ in range(self.size + 5)]
# self.a = [0 for _ in range(self.size + 5)]
else:
self.size = len(size_or_nums)
# self.a = [0 for _ in range(self.size + 5)]
self.c = [0 for _ in range(self.size + 5)]
for i, v in enumerate(size_or_nums):
self.add_point(i + 1, v)
def add_point(self, i, v): # 单点增加,下标从1开始
# self.a[i] += v
while i <= self.size:
self.c[i] += v
i += i & -i
# def set_point(self, i, v): # 单点修改,下标从1开始 需要先计算差值,然后调用add
# self.add_point(i, v - self.a[i])
# self.a[i] = v
def sum_interval(self, l, r): # 区间求和,下标从1开始,计算闭区间[l,r]上的和
return self.sum_prefix(r) - self.sum_prefix(l - 1)
def sum_prefix(self, i): # 前缀求和,下标从1开始
s = 0
while i >= 1:
s += self.c[i]
# i -= i&-i
i &= i - 1
return s
def min_right(self, i):
"""寻找[i,size]闭区间上第一个正数(不为0的数),注意i是1-indexed。若没有返回size+1"""
p = self.sum_prefix(i)
if i == 1:
if p > 0:
return i
else:
if p > self.sum_prefix(i - 1):
return i
l, r = i, self.size + 1
while l + 1 < r:
mid = (l + r) >> 1
if self.sum_prefix(mid) > p:
r = mid
else:
l = mid
return r
def lowbit(self, x):
return x & -x
class NumArray:
def __init__(self, nums: List[int]):
self.tree = BinIndexTree(nums)
def update(self, index: int, val: int) -> None:
self.tree.set_point(index+1,val)
def sumRange(self, left: int, right: int) -> int:
return self.tree.sum_interval(left+1,right+1)
例题: 1589. 所有排列中的最大和
这题其实应该用差分数组,可以省一层log。思想就是树状数组的IUPQ模型。
树状数组
经典案例,要用差分数组理解:
这个实际上是用树状数组维护原数组的差分数组c[i]=a[i]-a[i-1]
求原数组的值a[i]实际上是差分数组的前缀sum(c[0]…c[i]),所以get a[i]可以用sum c[i]表示,
而原数组a区间[x,y]更新+v,产生的效果是:x位置比x-1位置大了v,y+1位置比y小了v;
对差分数组c来说,产生变化的就是c[x]增加了v,c[y+1]减小了v,因为c数组代表的是a中每个数比前一个数的差。
class BinIndexTreeUpdateInterval:
def __init__(self, size_or_nums ): # 树状数组,下标需要从1开始
# 如果size 是数字,那就设置size和空数据;如果size是数组,那就是a
if isinstance(size_or_nums, int):
self.size = size_or_nums
self.c = [0 for _ in range(self.size+5)]
else:
self.size = len(size_or_nums)
self.c = [0 for _ in range(self.size+5)]
for i,v in enumerate(size_or_nums):
self.add_interval(i+1,i+1,v)
def add_point(self,i,v): # 单点增加,下标从1开始;不支持直接调用,这里增加的是差分数组的单点
while i<=self.size :
self.c[i] += v
i += self.lowbit(i)
def sum_prefix(self,i): # 前缀求和,下标从1开始;不支持直接调用,这里求和的是差分数组的前缀和
s = 0
while i >= 1:
s += self.c[i]
i -= self.lowbit(i)
return s
def add_interval(self,l,r,v): # 区间加,下标从1开始,把[l,r]闭区间都加v
self.add_point(l,v)
self.add_point(r+1,-v)
def query_point(self,i): # 单点询问值,下标从1开始,返回i位置的值
return self.sum_prefix(i)
def lowbit(self,x):
return x&-x
class Solution:
def maxSumRangeQuery(self, nums: List[int], requests: List[List[int]]) -> int:
n = len(nums)
mod = 10**9+7
nums.sort(reverse=True)
tree = BinIndexTreeUpdateInterval(n+5)
for start,end in requests:
tree.add_interval(start+1,end+1,1)
freq = [0] * n
for i in range(n):
freq[i] = tree.query_point(i+1)
ans = 0
freq.sort(reverse=True)
for i in range(n):
if freq[i] ==0:
break
ans += freq[i]*nums[i]%mod
ans %= mod
return ans
题目: P3372 【模板】线段树 1
(式子①)
那么 sigma(a,n) = 式子①
=n*sigma(d,n) - sigma(d2,n)
import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *
if sys.hexversion == 50924784:
sys.stdin = open('cfinput.txt')
RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
MOD = 10 ** 9 + 7
"""https://www.luogu.com.cn/problem/P3372
模板题,区间加,区间求和。
"""
class BinIndexTreeRURQ:
def __init__(self, size_or_nums): # 树状数组,区间加区间求和,下标需要从1开始
# 如果size 是数字,那就设置size和空数据;如果size是数组,那就是a
if isinstance(size_or_nums, int):
self.size = size_or_nums
self.c = [0 for _ in range(self.size + 5)]
self.c2 = [0 for _ in range(self.size + 5)]
else:
self.size = len(size_or_nums)
self.c = [0 for _ in range(self.size + 5)]
self.c2 = [0 for _ in range(self.size + 5)]
for i, v in enumerate(size_or_nums):
self.add_interval(i + 1, i + 1, v)
def add_point(self, c, i, v): # 单点增加,下标从1开始;不支持直接调用,这里增加的是差分数组的单点,同步修改c2
while i <= self.size:
c[i] += v
i += -i&i
def sum_prefix(self, c, i): # 前缀求和,下标从1开始;不支持直接调用,这里求和的是差分数组的前缀和;传入c决定怎么计算,但是不要直接调用 无视吧
s = 0
while i >= 1:
s += c[i]
i -= -i&i
return s
def add_interval(self, l, r, v): # 区间加,下标从1开始,把[l,r]闭区间都加v
self.add_point(self.c, l, v)
self.add_point(self.c, r + 1, -v)
self.add_point(self.c2, l, (l-1)*v)
self.add_point(self.c2, r + 1, -v*r)
def sum_interval(self, l, r): # 区间求和,下标从1开始,返回闭区间[l,r]上的求和
return self.sum_prefix(self.c, r) * r - self.sum_prefix(self.c2, r) - self.sum_prefix(self.c, l - 1) * (
l - 1) + self.sum_prefix(self.c2, l - 1)
def query_point(self, i): # 单点询问值,下标从1开始,返回i位置的值
return self.sum_prefix(self.c, i)
def lowbit(self, x):
return x & -x
# ms
def solve(n, m, a, qs):
tree = BinIndexTreeRURQ(a)
ans = []
for q in qs:
if q[0] == 1:
l, r, x = q[1], q[2], q[3]
tree.add_interval(l, r, x)
elif q[0] == 2:
l, r = q[1], q[2]
ans.append(tree.sum_interval(l, r))
print('\n'.join(map(str, ans)))
if __name__ == '__main__':
n, m = RI()
a = RILST()
q = []
for _ in range(m):
q.append(RILST())
solve(n, m, a, q)
例题: CF522 D. Closest Equals
这是20220923的茶。
import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
if sys.hexversion == 50924784:
sys.stdin = open('cfinput.txt')
# input = sys.stdin.readline
# input_int = sys.stdin.buffer.readline
# RI = lambda: map(int, input_int().split())
# RS = lambda: input().strip().split()
# RILST = lambda: list(RI())
RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: sys.stdin.readline().strip().split()
RILST = lambda: list(RI())
MOD = 10 ** 9 + 7
"""https://codeforces.com/problemset/problem/522/D
输入 n(2≤n≤2e5) 和 m(1≤m≤2e5);然后输入一个长为 n 的数组 a(-1e9≤a[i]≤1e9),数组下标从 1 开始;然后输入 m 个询问,每个询问表示一个数组 a 内的闭区间 [L,R] (1≤L≤R≤n)。
对每个询问,输出区间内的相同元素下标之间的最小差值。如果区间内不存在相同元素,输出 -1。
输入
5 3
1 1 2 3 2
1 5
2 4
3 5
输出
1
-1
2
输入
6 5
1 2 1 3 2 3
4 6
1 3
2 5
2 4
1 6
输出
2
2
3
-1
2
"""
class ITreeNode:
__slots__ = ['val', 'l', 'r']
def __init__(self, l=None, r=None, v=0):
self.val, self.l, self.r = v, l, r
class IntervalTree:
def __init__(self):
self.root = ITreeNode()
def update_from_son(self, node):
node.val = min(node.l.val if node.l else inf, node.r.val if node.r else inf)
return node
def update_point(self, node, l, r, index, val):
if index < l or r < index:
return None
if not node:
node = ITreeNode(v=inf)
if l == r:
node.val = val
return node
mid = (l + r) // 2
if index <= mid:
node.l = self.update_point(node.l, l, mid, index, val)
else:
node.r = self.update_point(node.r, mid + 1, r, index, val)
return self.update_from_son(node)
def query(self, node, l, r, x, y):
if not node or y < l or r < x:
return inf
if x <= l and r <= y:
return node.val
mid = (l + r) // 2
s = inf
if x <= mid:
s = min(s, self.query(node.l, l, mid, x, y))
if mid < y:
s = min(s, self.query(node.r, mid + 1, r, x, y))
return s
class BinIndexTreeMax:
def __init__(self, size):
self.size = size
self.a = [-inf for _ in range(size + 5)]
self.h = self.a[:]
self.mx = -inf
def update(self, x, v):
if v > self.mx:
self.mx = v
a = self.a
h = self.h
a[x] = v
while x <= self.size:
if h[x] < v:
h[x] = v
else:
break
x += self.lowbit(x)
def query(self, l, r):
a = self.a
h = self.h
ans = a[r]
while l != r:
r -= 1
while r - self.lowbit(r) > l:
if ans < h[r]:
ans = h[r]
if ans == self.mx:
break
r -= self.lowbit(r)
# ans = min(ans, self.a[r])
if ans < a[r]:
ans = a[r]
if ans == self.mx:
break
return ans
def lowbit(self, x):
return x & -x
class BinIndexTreeMin:
def __init__(self, size):
self.size = size
self.a = [inf for _ in range(size + 5)]
self.h = self.a[:]
self.mn = inf
def update(self, x, v):
if v < self.mn:
self.mn = v
a = self.a
h = self.h
a[x] = v
while x <= self.size:
if h[x] > v:
h[x] = v
else:
break
x += self.lowbit(x)
def query(self, l, r):
a = self.a
h = self.h
ans = a[r]
while l != r:
r -= 1
while r - self.lowbit(r) > l:
if ans > h[r]:
ans = h[r]
if ans == self.mn:
break
r -= self.lowbit(r)
# ans = min(ans, self.a[r])
if ans > a[r]:
ans = a[r]
if ans == self.mn:
break
return ans
def lowbit(self, x):
return x & -x
# 2932 ms
def solve(n, m, a, q):
a = [0] + a
q = sorted([(l, r, i) for i, (l, r) in enumerate(q)], key=lambda x: x[1])
# tree = IntervalTree()
tree = BinIndexTreeMin(n + 5)
pre = {}
j = 1
ans = [-1] * m
for l, r, i in q:
while j <= r:
v = a[j]
if v in pre:
idx = pre[v]
d = j - idx
tree.update(idx, d)
pre[v] = j
j += 1
cur = tree.query(l, r)
# print(tree.a, tree.h, l, r, cur)
if cur < inf:
ans[i] = cur
# print(ans)
print('\n'.join(map(str, ans)))
if __name__ == '__main__':
n, m = RI()
a = RILST()
q = []
for _ in range(m):
q.append(RILST())
solve(n, m, a, q)
例题: 2407. 最长递增子序列 II
周赛T4,当时用线段树做的;实际测试线段树的表现甚至优于树状数组,奇怪极了。
class BinIndexTreeMax:
def __init__(self, size_or_nums ): # 树状数组,下标需要从1开始
# 如果size 是数字,那就设置size和空数组;如果size是数组,那就是a
if isinstance(size_or_nums, int):
self.size = size_or_nums
self.h = [-inf for _ in range(self.size+5)]
self.a = [-inf for _ in range(self.size+5)]
else:
self.size = len(size_or_nums)
self.a = [-inf for _ in range(self.size+5)]
self.h = [-inf for _ in range(self.size+5)]
for i,v in enumerate(size_or_nums):
self.set_point(i+1,v)
def set_point(self,x,v): # 单点修改,下标从1开始 修改原数组和h数组
self.a[x] = v
while x <= self.size:
# self.h[x] = max(self.h[x], self.a[lx])
if self.h[x] < v:
self.h[x] = v
x += (x&-x)
def query_interval_max(self,l,r): # 区间询问最大值,下标从1开始
ans = -inf
while l <= r:
# ans = max(self.a[r], ans)
if ans < self.a[r]:
ans = self.a[r]
r -= 1
while r - (r&-r) >= l:
# ans = max(self.h[r], ans)
if ans < self.h[r]:
ans = self.h[r]
r -= (r&-r)
return ans
def lowbit(self,x):
return x&-x
class Solution:
def lengthOfLIS(self, nums: List[int], k: int) -> int:
n = len(nums)
mx = max(nums)
tree = BinIndexTreeMax(mx)
ans = 0
for v in nums:
l = max(1,v-k)
r = v-1
ret = tree.query_interval_max(l,r)+1
if ret < 1:
ret = 1
tree.set_point(v,ret)
return tree.query_interval_max(1,mx)
例题: 6292. 子矩阵元素加 1
#二维树状数组,维护区域和
class BinTree2DIUPQ:
def __init__(self, m, n):
self.n = n
self.m = m
self.tree = [[0] * (n + 1) for _ in range(m + 1)]
def lowbit(self, x):
return x & (-x)
def _update_point(self, x, y, val):
m,n,tree = self.m,self.n,self.tree
while x <= m:
y1 = y
while y1 <= n:
tree[x][y1] += val
y1 += y1&-y1
x += x&-x
def _sum_prefix(self, x, y):
res = 0
tree = self.tree
while x > 0:
y1 = y
while y1 > 0:
res += tree[x][y1]
y1 &= y1-1
x &= x-1
return res
def add_interval(self,x1,y1,x2,y2,v):
self._update_point(x1,y1,v)
self._update_point(x2+1,y1,-v)
self._update_point(x1,y2+1,-v)
self._update_point(x2+1,y2+1,v)
def query_point(self,x,y):
return self._sum_prefix(x,y)
class Solution:
def rangeAddQueries(self, n: int, queries: List[List[int]]) -> List[List[int]]:
tree = BinTree2DIUPQ(n, n)
for x1, y1, x2, y2 in queries:
tree.add_interval(x1+1,y1+1,x2+1,y2+1,1)
res = [[0] * n for _ in range(n)]
for i in range(n):
for j in range(n):
res[i][j] = tree.query_point(i + 1, j + 1)
return res