线段树通常用来维护区间询问:通过二分的方式把数组划分成一个个小区间,在每个小区间上维护某种属性(极值、求和等),从而以 O(log n) 的时间完成区间查询和更新。
例题: 699. 掉落的方块
用线段树维护x轴上每个线段(区间)的最大高度
class IntervalTree:
    """Range-assign / range-max segment tree over positions 1..size.

    Nodes are heap-indexed: node p covers [l, r]; its children are 2p
    (covering [l, mid]) and 2p + 1 (covering [mid + 1, r]). Every public
    call starts at the root: ``(p, l, r) == (1, 1, size)``.
    """

    def __init__(self, size):
        self.size = size
        # 4 * size slots are enough for a heap-indexed segment tree.
        self.interval_tree = [0 for _ in range(size * 4)]
        # Pending "assign this value to the whole subtree" tag per node.
        # None means "no pending assignment"; using 0 as that sentinel would
        # silently drop assignments of the value 0.
        self.lazys = [None for _ in range(size * 4)]

    def give_lay_to_son(self, p, l, r):
        """Push node p's pending assignment down to its two children."""
        lazys = self.lazys
        val = lazys[p]
        if val is None:
            return
        interval_tree = self.interval_tree
        for son in (p * 2, p * 2 + 1):
            interval_tree[son] = val
            lazys[son] = val
        lazys[p] = None

    def update(self, p, l, r, x, y, val):
        """Set every position in [x, y] (1-based, inclusive) to ``val``.

        ``p`` is the current node and covers [l, r].
        """
        if y < l or r < x:
            return
        interval_tree = self.interval_tree
        if x <= l and r <= y:
            # Fully covered: assign here and defer the children via the tag.
            interval_tree[p] = val
            self.lazys[p] = val
            return
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        if x <= mid:
            self.update(p * 2, l, mid, x, y, val)
        if mid < y:
            self.update(p * 2 + 1, mid + 1, r, x, y, val)
        interval_tree[p] = max(interval_tree[p * 2], interval_tree[p * 2 + 1])

    def query(self, p, l, r, x, y):
        """Return the maximum over positions [x, y] (1-based, inclusive).

        The result starts from 0, so a range disjoint from [l, r] yields 0
        and stored values are expected to be non-negative.
        """
        if y < l or r < x:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        s = 0
        if x <= mid:
            s = max(s, self.query(p * 2, l, mid, x, y))
        if mid < y:
            s = max(s, self.query(p * 2 + 1, mid + 1, r, x, y))
        return s
class Solution:
    def fallingSquares(self, positions: List[List[int]]) -> List[int]:
        """Return the height of the tallest stack after each square lands.

        ``positions[i] == [left, side]`` drops a square of side ``side`` whose
        bottom edge spans x in [left, left + side).

        A range-assign / range-max segment tree over the x axis stores the
        current height of every x segment. For each new square, the highest
        point h under its bottom edge is queried; the square rests there and
        its whole bottom edge is raised to h + side. The tree root then holds
        the overall maximum height.
        """
        # Coordinates go up to 1e8, so compress them: only square edge
        # points (left and left + side) can change the height profile.
        hashes = sorted(
            {left for left, _ in positions} | {left + side for left, side in positions}
        )
        tree_size = len(hashes)
        tree = IntervalTree(tree_size)
        heights = []
        for left, side in positions:
            # Tree position k + 1 stands for the x segment
            # [hashes[k], hashes[k + 1]), so a square [left, right) covers
            # positions l + 1 .. r. Squares that only touch at an edge
            # therefore do not stack on each other.
            l = bisect_left(hashes, left)
            r = bisect_left(hashes, left + side)
            h = tree.query(1, 1, tree_size, l + 1, r)
            tree.update(1, 1, tree_size, l + 1, r, h + side)
            heights.append(tree.interval_tree[1])
        return heights
链接: 850. 矩形面积 II
线段树
经典案例,涉及离散化、扫描线,细节很多,非常难写
class IntervalTreeNode:
    """A mutable node of the union-length segment tree.

    Attributes:
        len: covered length recorded for this node's range.
        cover: counter adjusted by the tree's insert operation (+1 / -1).
    """

    def __init__(self, len, cover):
        # ``len`` shadows the builtin but is kept as the public parameter name.
        self.len, self.cover = len, cover
class IntervalTree:
    """Segment tree tracking the total length covered by active y-intervals.

    Positions 1..m index the sorted coordinate list ``ys`` (position k means
    ys[k - 1]). A node covering [l, r] represents the span ys[l-1]..ys[r-1].
    Its children split at ``mid`` into [l, mid] and [mid, r]; they share the
    middle point. A node with r == l + 1 is a leaf covering one elementary
    span. The root's ``len`` is the total covered length.
    """

    def __init__(self, size, ys=None):
        self.size = size
        # Build one distinct node per slot: ``[IntervalTreeNode(0, 0)] * n``
        # would make every slot alias a single shared instance.
        self.interval_tree = [IntervalTreeNode(0, 0) for _ in range(size * 4)]
        self.ys = ys

    def update_from_son(self, p, l, r):
        """Recompute node p's covered length from its cover count or children."""
        interval_tree = self.interval_tree
        node = interval_tree[p]
        if node.cover > 0:
            # Covered by at least one active interval: the whole span counts.
            node.len = self.ys[r - 1] - self.ys[l - 1]
        elif l + 1 == r:
            # Uncovered leaf.
            node.len = 0
        else:
            node.len = interval_tree[p * 2].len + interval_tree[p * 2 + 1].len

    def insert(self, p, l, r, x, y, cover):
        """Add ``cover`` to every elementary span between points x and y.

        Only the spans strictly between point x and point y are touched, so
        x >= y is a no-op. ``p`` is the current node and covers [l, r].
        """
        # Compare spans, not points: [l, r] and [x, y] share no elementary
        # span when y <= l or r <= x. The previous strict comparison let a
        # zero-height span (x == y == r) descend forever into the right
        # child [mid, r], which equals [l, r] at a leaf.
        if y <= l or r <= x:
            return
        if x <= l and r <= y:
            self.interval_tree[p].cover += cover
            self.update_from_son(p, l, r)
            return
        mid = (l + r) // 2
        if x < mid:
            self.insert(p * 2, l, mid, x, y, cover)
        if y > mid:
            self.insert(p * 2 + 1, mid, r, x, y, cover)
        self.update_from_son(p, l, r)
class LineY:
    """A vertical rectangle edge used by the x-direction sweep.

    ``x`` is the edge's x coordinate and ``y1``/``y2`` bound its y extent.
    ``cover`` is the weight added to that y-range when the sweep reaches the
    edge (rectangleArea in this file uses +1 for left edges and -1 for right
    edges).
    """

    def __init__(self, x, y1, y2, cover):
        self.x, self.y1, self.y2, self.cover = x, y1, y2, cover
class Solution:
    def rectangleArea(self, rectangles: List[List[int]]) -> int:
        """Return the area of the union of ``rectangles``, modulo 1e9 + 7.

        Sweep a vertical line from left to right. Each rectangle contributes a
        +1 edge at x1 and a -1 edge at x2 over its y-range. The segment tree
        keeps the total y-length currently covered, so between consecutive
        edges the swept area is (dx * covered length).
        """
        mod = 10**9 + 7
        lines = []  # all vertical edges
        ys = set()  # y coordinates to compress
        for x1, y1, x2, y2 in rectangles:
            # Zero-height rectangles add no area. Skipping them also keeps
            # empty y-spans (y1 == y2) out of the segment tree.
            if y1 == y2:
                continue
            lines.append(LineY(x1, y1, y2, 1))
            lines.append(LineY(x2, y1, y2, -1))
            ys.add(y1)
            ys.add(y2)
        lines.sort(key=lambda line: line.x)
        ys = sorted(ys)
        interval_tree = IntervalTree(len(lines), ys=ys)
        ans = 0
        prev_x = None
        for line in lines:
            if prev_x is not None:
                # Area swept since the previous edge, using the coverage
                # that was active over that strip.
                covered = interval_tree.interval_tree[1].len
                ans = (ans + (line.x - prev_x) * covered) % mod
            prev_x = line.x
            y1 = bisect_left(ys, line.y1)
            y2 = bisect_left(ys, line.y2)
            interval_tree.insert(1, 1, len(ys), y1 + 1, y2 + 1, line.cover)
        return ans
链接: 307. 区域和检索 - 数组可修改
线段树
经典案例,比区间求极值麻烦一点点
class IntervalTree:
    """Point-add / range-sum segment tree over 1-based positions 1..size.

    Heap-indexed: node p covers [l, r], children 2p -> [l, mid] and
    2p + 1 -> [mid + 1, r]. Public calls start at (p, l, r) == (1, 1, size).
    """

    def __init__(self, size, nums=None):
        self.size = size
        self.nums = nums
        self.interval_tree = [0] * (size * 4)
        # A falsy ``nums`` (None or empty) leaves every node at 0.
        if nums:
            self.build_tree(1, 1, size)

    def build_tree(self, p, l, r):
        """Fill node p (covering [l, r]) and its subtree from ``self.nums``."""
        tree = self.interval_tree
        if l == r:
            tree[p] = self.nums[l - 1]
            return
        mid = (l + r) // 2
        left, right = 2 * p, 2 * p + 1
        self.build_tree(left, l, mid)
        self.build_tree(right, mid + 1, r)
        tree[p] = tree[left] + tree[right]

    def add_point(self, p, l, r, index, add):
        """Add ``add`` to position ``index``; out-of-range indices are ignored."""
        if index < l or r < index:
            return
        tree = self.interval_tree
        # Walk root-to-leaf; every node on the path contains ``index``.
        while True:
            tree[p] += add
            if l == r:
                return
            mid = (l + r) // 2
            if index <= mid:
                p, r = 2 * p, mid
            else:
                p, l = 2 * p + 1, mid + 1

    def sum_interval(self, p, l, r, x, y):
        """Return the sum over positions [x, y] (inclusive); 0 if disjoint."""
        if y < l or r < x:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        mid = (l + r) // 2
        total = self.sum_interval(2 * p, l, mid, x, y) if x <= mid else 0
        if y > mid:
            total += self.sum_interval(2 * p + 1, mid + 1, r, x, y)
        return total
class NumArray:
    """Array supporting point assignment and inclusive range-sum queries."""

    def __init__(self, nums: List[int]):
        self.size = len(nums)
        # Keep a private copy: ``update`` writes into this list, and aliasing
        # the caller's list would silently mutate it.
        self.nums = list(nums)
        self.interval_tree = IntervalTree(self.size, self.nums)

    def update(self, index: int, val: int) -> None:
        """Set nums[index] = val (0-based index)."""
        # The tree only supports point *adds*, so push the difference.
        add = val - self.nums[index]
        self.nums[index] = val
        self.interval_tree.add_point(1, 1, self.size, index + 1, add)

    def sumRange(self, left: int, right: int) -> int:
        """Return nums[left] + ... + nums[right] (0-based, inclusive)."""
        return self.interval_tree.sum_interval(1, 1, self.size, left + 1, right + 1)
链接: 327. 区间和的个数
线段树
这题麻烦一点,求区间内数字数量,每个数初始化为0,插入时候+1,计数转化为求和
参考链接: [LeetCode解题报告]327. 区间和的个数
class IntervalTree:
    """Occurrence-counting segment tree over 1-based positions 1..size.

    Each leaf counts how many times its position was inserted; internal
    nodes hold the sum of their children, so a range query returns how many
    inserted positions fall inside the range.
    """

    def __init__(self, size):
        self.size = size
        self.interval_tree = [0] * (size * 4)

    def insert(self, p, l, r, index):
        """Record one occurrence of ``index``; out-of-range indices are ignored."""
        if not l <= index <= r:
            return
        tree = self.interval_tree
        # Descend to the leaf, remembering the internal nodes passed through.
        path = []
        while l != r:
            path.append(p)
            mid = (l + r) // 2
            if index <= mid:
                p, r = 2 * p, mid
            else:
                p, l = 2 * p + 1, mid + 1
        tree[p] += 1
        # Re-aggregate the ancestors bottom-up.
        for node in reversed(path):
            tree[node] = tree[2 * node] + tree[2 * node + 1]

    def query(self, p, l, r, x, y):
        """Return how many recorded occurrences lie in [x, y] (inclusive)."""
        if y < l or r < x:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        mid = (l + r) // 2
        total = self.query(2 * p, l, mid, x, y) if x <= mid else 0
        if y > mid:
            total += self.query(2 * p + 1, mid + 1, r, x, y)
        return total
class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count index pairs i < j with lower <= prefix[j] - prefix[i] <= upper.

        ``prefix`` is the prefix-sum array of ``nums`` (with a leading 0), so
        each pair corresponds to one subarray whose sum is in [lower, upper].
        Scanning j left to right, the valid earlier prefixes are exactly those
        in [prefix[j] - upper, prefix[j] - lower]. A counting segment tree over
        compressed values answers that in O(log n).
        """
        prefix = list(accumulate(nums, initial=0))
        # Every value that is inserted or used as a query bound needs a rank.
        hashes = sorted(
            set(prefix)
            | {value - lower for value in prefix}
            | {value - upper for value in prefix}
        )
        tree_size = len(hashes)
        tree = IntervalTree(tree_size)
        cnt = 0
        for cur in prefix:
            lo = bisect_left(hashes, cur - upper) + 1
            hi = bisect_left(hashes, cur - lower) + 1
            cnt += tree.query(1, 1, tree_size, lo, hi)
            tree.insert(1, 1, tree_size, bisect_left(hashes, cur) + 1)
        return cnt
链接: 732. 我的日程安排表 III
线段树
这题由于只能在线做,不能做离散化,因此需要用字典维护线段端点,实现动态开点。
参考链接: [LeetCode解题报告] 732. 我的日程安排表 III
class IntervalTree:
    """Dynamically allocated range-add / range-max segment tree.

    Node data lives in dicts keyed by heap index (node p has children 2p and
    2p + 1), so only nodes that are actually touched get stored. This lets
    the tree span a huge coordinate range without coordinate compression.
    """

    def __init__(self):
        # node -> maximum value inside that node's range
        self.interval_tree = collections.defaultdict(int)
        # node -> addition still owed to that node's children
        self.lazys = collections.defaultdict(int)

    def give_lay_to_son(self, p, l, r):
        """Push node p's pending addition down to its two children."""
        # .get avoids inserting a zero entry for every node we merely inspect.
        lazy = self.lazys.get(p, 0)
        if lazy == 0:
            return
        interval_tree = self.interval_tree
        lazys = self.lazys
        for son in (p * 2, p * 2 + 1):
            interval_tree[son] += lazy
            lazys[son] += lazy
        del lazys[p]

    def add(self, p, l, r, x, y, val):
        """Add ``val`` to every position in [x, y] (inclusive).

        ``p`` is the current node and covers [l, r].
        """
        # Disjoint range: nothing to do (the original note says leaving this
        # check out caused TLE).
        if r < x or y < l:
            return
        interval_tree = self.interval_tree
        if x <= l and r <= y:
            interval_tree[p] += val
            self.lazys[p] += val
            return
        # MyCalendarThree only ever reads the root, so pushing down could be
        # replaced by adding the node's own lazy value on the way back up;
        # that would be a bit faster. Pushing keeps query() correct in general.
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        if x <= mid:
            self.add(p * 2, l, mid, x, y, val)
        if mid < y:
            self.add(p * 2 + 1, mid + 1, r, x, y, val)
        interval_tree[p] = max(
            interval_tree.get(p * 2, 0), interval_tree.get(p * 2 + 1, 0)
        )

    def query(self, p, l, r, x, y):
        """Return the maximum over positions [x, y]; 0 if the range is disjoint."""
        # Without this guard a query outside [l, r] descends into empty
        # ranges (l > r) and never terminates.
        if r < x or y < l:
            return 0
        if x <= l and r <= y:
            return self.interval_tree.get(p, 0)
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        s = 0
        if x <= mid:
            s = max(s, self.query(p * 2, l, mid, x, y))
        if mid < y:
            s = max(s, self.query(p * 2 + 1, mid + 1, r, x, y))
        return s
class MyCalendarThree:
    """Track the maximum number of simultaneously overlapping bookings."""

    def __init__(self):
        self.tree = IntervalTree()

    def book(self, start: int, end: int) -> int:
        """Book the half-open interval [start, end); return the max overlap.

        Each booking occupies the integer time slots start..end-1. LeetCode 732
        bounds bookings by 0 <= start < end <= 10**9, so every slot lies in
        [0, 10**9]. The root range must start at 0: with a root of
        [1, 10**9 + 1], slot 0 was outside the tree and a booking such as
        [0, 1) was silently dropped.
        """
        self.tree.add(1, 0, 10**9, start, end - 1, 1)
        return self.tree.interval_tree[1]
链接: 6206. 最长递增子序列 II
线段树
这题算是打开了 LIS 的一个新优化思路:传统 DP 对每个位置都要 O(n) 地扫描前面的状态,总复杂度是 O(n^2);用线段树按值域维护“以某个值结尾的最长长度”,可以把单次查询降到 O(log n)。
参考链接: [LeetCode周赛复盘] 第 310 场周赛20220911
class IntervalTree:
    """Point "raise to at least val" / range-max segment tree over 1..size.

    ``nums`` is stored but not used for building; every node starts at 0.
    """

    def __init__(self, size, nums=None):
        self.size = size
        self.nums = nums
        self.interval_tree = [0] * (size * 4)

    def update_point(self, p, l, r, index, val):
        """Raise position ``index`` to max(current, val), along with its ancestors."""
        if index < l or r < index:
            return
        tree = self.interval_tree
        # Every node on the root-to-leaf path contains ``index``.
        while True:
            tree[p] = max(tree[p], val)
            if l == r:
                return
            mid = (l + r) // 2
            if index <= mid:
                p, r = 2 * p, mid
            else:
                p, l = 2 * p + 1, mid + 1

    def query(self, p, l, r, x, y):
        """Return the maximum over positions [x, y]; 0 if nothing overlaps."""
        if y < l or r < x:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        mid = (l + r) // 2
        candidates = [0]
        if x <= mid:
            candidates.append(self.query(2 * p, l, mid, x, y))
        if y > mid:
            candidates.append(self.query(2 * p + 1, mid + 1, r, x, y))
        return max(candidates)
class Solution:
    def lengthOfLIS(self, nums: List[int], k: int) -> int:
        """Length of the longest strictly increasing subsequence of ``nums``
        whose adjacent elements differ by at most ``k``.

        Tree position v holds the best length of such a subsequence ending in
        value v. For a new value v, the best predecessor value lies in
        [v - k, v - 1], which is one range-max query.
        """
        top = max(nums)
        tree = IntervalTree(top)
        best = 0
        for v in nums:
            # Bounds are clamped at 0; position 0 is below the tree's 1-based
            # range, so it contributes nothing to the query.
            length = tree.query(1, 1, top, max(0, v - k), max(0, v - 1)) + 1
            tree.update_point(1, 1, top, v, length)
            best = max(best, length)
        return best
链接: 6358. 更新数组后处理求和查询
链接: P3870 [TJOI2009] 开关
class IntervalTree:
    """Range-flip / range-count segment tree over 0/1 bits at positions 1..size.

    Each node stores how many 1 bits lie in its range. A lazy bit of 1 means
    the node's children still have to be flipped.
    """

    def __init__(self, size):
        self.size = size
        self.interval_tree = [0] * (size * 4)  # count of ones per node
        self.lazys = [0] * (size * 4)  # 1 = pending flip for the children

    def give_lay_to_son(self, p, l, r):
        """Apply node p's pending flip to its two children."""
        lazys = self.lazys
        if not lazys[p]:
            return
        tree = self.interval_tree
        mid = (l + r) // 2
        for son, width in ((2 * p, mid - l + 1), (2 * p + 1, r - mid)):
            # Flipping a range turns its ones into zeros and vice versa.
            tree[son] = width - tree[son]
            lazys[son] ^= 1
        lazys[p] = 0

    def update(self, p, l, r, x, y, val):
        """Flip every bit in positions [x, y] (inclusive).

        ``val`` is accepted for signature compatibility but ignored: the
        operation always toggles and never assigns.
        """
        if y < l or r < x:
            return
        tree = self.interval_tree
        if x <= l and r <= y:
            tree[p] = (r - l + 1) - tree[p]
            self.lazys[p] ^= 1
            return
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        left, right = 2 * p, 2 * p + 1
        if x <= mid:
            self.update(left, l, mid, x, y, val)
        if y > mid:
            self.update(right, mid + 1, r, x, y, val)
        tree[p] = tree[left] + tree[right]

    def query(self, p, l, r, x, y):
        """Return the number of 1 bits in positions [x, y]; 0 if disjoint."""
        if y < l or r < x:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        total = self.query(2 * p, l, mid, x, y) if x <= mid else 0
        if y > mid:
            total += self.query(2 * p + 1, mid + 1, r, x, y)
        return total
class Solution:
    def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
        """Process flip / add / report queries and return every reported sum.

        ``nums1`` is a 0/1 array kept in a range-flip segment tree. The
        queries are:
          * [1, l, r]: flip nums1[l..r] (0-based, inclusive);
          * [2, p, _]: nums2[i] += nums1[i] * p for all i, which grows
            sum(nums2) by p * (number of ones in nums1);
          * [3, _, _]: report the current sum(nums2).
        """
        n = len(nums1)
        total = sum(nums2)
        tree = IntervalTree(n)
        # Load the initial ones: flipping a 0 at a single position sets it to 1.
        for pos, bit in enumerate(nums1, start=1):
            if bit:
                tree.update(1, 1, n, pos, pos, 1)
        answers = []
        for op, a, b in queries:
            if op == 1:
                tree.update(1, 1, n, a + 1, b + 1, 1)
            elif op == 2:
                total += a * tree.query(1, 1, n, 1, n)
            else:
                answers.append(total)
        return answers