[python刷题模板] 线段树

[python刷题模板] 线段树

    • 一、 算法&数据结构
      • 1. 描述
      • 2. 复杂度分析
      • 3. 常见应用
      • 4. 常用优化
    • 二、 模板代码
      • 1. 区间更新,区间询问最大值(IUIQ)
      • 2. 矩形面积并
      • 3.单点更新,区间求和
      • 4.单点更新,区间求和
      • 5.区间更新,区间查询,无法离散化,动态开点。
      • 6.单点更新,区间查询最大值。
      • 7. 区间01翻转(异或),区间查询1的个数。
    • 三、其他
    • 四、更多例题

一、 算法&数据结构

1. 描述

线段树通常用来维护区间询问:通过二分的方式把数组递归地划分成若干区间块,在每个块上维护属性(极值、求和等),从而在 $O(\log n)$ 时间内完成查询和更新。

2. 复杂度分析

  1. 查询 query:$O(\log n)$
  2. 更新 update:$O(\log n)$

3. 常见应用

  1. 单点更新,区间求极值(最入门)
  2. 单点更新,区间求和(稍复杂)
  3. 区间更新,单点或区间求值:需要用到 lazytag(懒标记),否则一次区间更新要逐个修改区间内的所有叶子,复杂度会退化为 $O(n)$。

4. 常用优化

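  1. 设置 lazytag(懒标记)用于区间更新:当前节点的区间被更新区间完全包含时,只修改该节点并打上标记,不再向下递归;此后每次 update 和 query 要继续向下递归之前,都需要先把标记下传给子节点(模板代码中对应的函数名为 give_lay_to_son)。
  2. 离散化:线段树的下标是连续的整数位置。如果题目给的坐标是实数(浮点数)或取值范围过大的整数,可以把所有用到的坐标排序去重后映射为 1..m 的下标再建树;实际出现的不同坐标通常不多,所以数组长度不会太大。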

二、 模板代码

1. 区间更新,区间询问最大值(IUIQ)

例题: 699. 掉落的方块
用线段树维护x轴上每个线段(区间)的最大高度

class IntervalTree:
    """Segment tree supporting range assignment and range-maximum queries.

    Positions are 1-based ``[1, size]``; node ``p`` has children ``2p`` and
    ``2p+1``.  ``interval_tree[p]`` is the maximum value inside node p's
    range and ``lazys[p]`` is an assignment still owed to p's children.
    """

    def __init__(self, size):
        self.size = size
        self.interval_tree = [0 for _ in range(size*4)]
        # None means "no pending assignment".  A 0 sentinel made an
        # assignment of 0 indistinguishable from "nothing pending", so it was
        # never pushed down and stale child values leaked into later queries.
        self.lazys = [None for _ in range(size*4)]

    def give_lay_to_son(self, p, l, r):
        """Push node p's pending assignment (if any) down to both children.

        ``l`` and ``r`` are not needed for assignment; they are kept so the
        signature matches the other templates.
        """
        lazys = self.lazys
        val = lazys[p]
        if val is None:
            return
        interval_tree = self.interval_tree
        interval_tree[p*2] = val
        interval_tree[p*2+1] = val
        lazys[p*2] = val
        lazys[p*2+1] = val
        lazys[p] = None

    def update(self, p, l, r, x, y, val):
        """Assign ``val`` to every position in [x, y].

        ``p`` is the current node, covering ``[l, r]``; start the recursion
        with ``update(1, 1, size, x, y, val)``.
        """
        if y < l or r < x:
            return
        interval_tree = self.interval_tree
        if x <= l and r <= y:
            # Whole node overwritten: its max is val; children are deferred.
            interval_tree[p] = val
            self.lazys[p] = val
            return
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        if x <= mid:
            self.update(p*2, l, mid, x, y, val)
        if mid < y:
            self.update(p*2+1, mid+1, r, x, y, val)
        interval_tree[p] = max(interval_tree[p*2], interval_tree[p*2+1])

    def query(self, p, l, r, x, y):
        """Return the maximum over positions [x, y].

        0 is used as the starting value (and for an empty overlap), so the
        result is never below 0; stored negative values would read as 0.
        """
        if y < l or r < x:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        s = 0
        if x <= mid:
            s = max(s, self.query(p*2, l, mid, x, y))
        if mid < y:
            s = max(s, self.query(p*2+1, mid+1, r, x, y))
        return s

class Solution:
    def fallingSquares(self, positions: List[List[int]]) -> List[int]:
        """Return the tallest stack height after each square has landed.

        The segment tree keeps, for every compressed x-segment, the highest
        top surface seen so far.  A new square with side ``side`` falls onto
        the highest point under its base ``[left, left + side)`` and raises
        that whole base to ``base + side``.  After each drop the root holds
        the overall maximum height.
        """
        # Left edges go up to 10**8, so compress: only square edges matter.
        edges = sorted({left for left, _ in positions}
                       | {left + side for left, side in positions})
        m = len(edges)
        seg = IntervalTree(m)
        result = []
        for left, side in positions:
            # Tree position i (1-based) is the gap edges[i-1] .. edges[i],
            # so the base [left, left + side) maps to positions lo..hi.
            lo = bisect_left(edges, left) + 1
            hi = bisect_left(edges, left + side)
            base = seg.query(1, 1, m, lo, hi)
            seg.update(1, 1, m, lo, hi, base + side)
            result.append(seg.interval_tree[1])
        return result

2. 矩形面积并

链接: 850. 矩形面积 II

线段树经典案例,涉及离散化扫描线,细节很多,非常难写

class IntervalTreeNode:
    """Node of the rectangle-union segment tree.

    ``len``:   total y-length currently covered inside this node's range.
    ``cover``: number of active edges that cover this node's whole range.
    """

    def __init__(self, len, cover):
        self.len = len
        self.cover = cover


class IntervalTree:
    """Segment tree for the scan-line rectangle-union algorithm.

    Positions are 1-based indices into the sorted distinct y-coordinates
    ``ys``.  Node ``[l, r]`` stands for the y-span ``ys[l-1] .. ys[r-1]``;
    its children are ``[l, mid]`` and ``[mid, r]`` (they share ``mid``), and
    ``[l, l+1]`` is an elementary leaf.
    """

    def __init__(self, size, ys=None):
        self.size = size
        # One distinct node per slot: ``[IntervalTreeNode(0, 0)] * n`` would
        # alias a single shared instance and corrupt every update.
        self.interval_tree = [IntervalTreeNode(0, 0) for _ in range(size*4)]
        self.ys = ys

    def update_from_son(self, p, l, r):
        """Recompute node p's covered length from its cover count/children."""
        interval_tree = self.interval_tree
        pn = interval_tree[p]
        if pn.cover > 0:
            # Fully covered by some active edge: the whole y-span counts.
            pn.len = self.ys[r-1] - self.ys[l-1]
        elif l + 1 == r:
            # Uncovered elementary leaf.
            pn.len = 0
        else:
            pn.len = interval_tree[p*2].len + interval_tree[p*2+1].len

    def insert(self, p, l, r, x, y, cover):
        """Add ``cover`` (+1 opens, -1 closes) to the y-span between points x and y.

        Start the recursion with ``insert(1, 1, len(ys), x, y, cover)``.
        """
        # Because children share ``mid``, a node overlaps [x, y] only if they
        # share at least one elementary interval: x < r and l < y.  Touching
        # at a single point (including the degenerate x == y) is no overlap.
        # The previous `y < l or r < x` test let x == y == r recurse forever
        # on a leaf [l, l+1] (where mid == l).
        if y <= l or r <= x:
            return
        interval_tree = self.interval_tree
        if x <= l and r <= y:
            interval_tree[p].cover += cover
            self.update_from_son(p, l, r)
            return
        mid = (l + r) // 2
        if x < mid:
            self.insert(p*2, l, mid, x, y, cover)
        if y > mid:
            self.insert(p*2+1, mid, r, x, y, cover)
        self.update_from_son(p, l, r)
            
class LineY:
    """A vertical rectangle edge at ``x`` spanning ``y1 .. y2``.

    ``cover`` is +1 for a left (opening) edge and -1 for a right (closing)
    edge.
    """

    def __init__(self, x, y1, y2, cover):
        self.x, self.y1, self.y2 = x, y1, y2
        self.cover = cover
        
class Solution:
    def rectangleArea(self, rectangles: List[List[int]]) -> int:
        """Return the area of the union of ``rectangles`` modulo 10**9 + 7.

        Sweep over the vertical edges in x order.  Between two consecutive
        edges the covered height is the tree root's ``len``, so the swept
        area is that height times the x-distance.
        """
        lines = []  # every vertical edge of every rectangle
        ys = set()  # distinct y-coordinates, for discretization
        for x1, y1, x2, y2 in rectangles:
            lines.append(LineY(x1, y1, y2, 1))   # left edge opens coverage
            lines.append(LineY(x2, y1, y2, -1))  # right edge closes it
            ys.add(y1)
            ys.add(y2)
        lines.sort(key=lambda line: line.x)
        line_count = len(lines)
        ys = sorted(ys)

        interval_tree = IntervalTree(line_count, ys=ys)

        ans = 0
        mod = 10**9 + 7
        for i, line in enumerate(lines):
            if i > 0:
                # Area swept since the previous edge: width * covered height.
                ans += (line.x - lines[i-1].x) * interval_tree.interval_tree[1].len
                ans %= mod
            y1 = bisect_left(ys, line.y1)
            y2 = bisect_left(ys, line.y2)
            # A zero-height edge covers nothing.  Skip only its insert (the
            # area above must still be accumulated).  Inserting it can make
            # the tree recurse forever when y1 == y2 is the top coordinate.
            if y1 < y2:
                interval_tree.insert(1, 1, len(ys), y1 + 1, y2 + 1, line.cover)

        return ans

3.单点更新,区间求和

链接: 307. 区域和检索 - 数组可修改

线段树经典案例,比区间求极值麻烦一点点

class IntervalTree:
    """Point-add / range-sum segment tree over 1-based positions [1, size].

    Node ``p`` has children ``2p`` and ``2p+1``; ``interval_tree[p]`` is the
    sum of node p's range.  If ``nums`` is given (non-empty), the tree is
    built from it, with ``nums[i-1]`` at position ``i``.
    """

    def __init__(self, size, nums=None):
        self.size = size
        self.nums = nums
        self.interval_tree = [0] * (4 * size)
        if nums:
            self.build_tree(1, 1, size)

    def build_tree(self, p, l, r):
        """Fill node p (covering [l, r]) and its subtree from ``self.nums``."""
        tree = self.interval_tree
        if l == r:
            tree[p] = self.nums[l - 1]
            return
        mid = (l + r) // 2
        left, right = 2 * p, 2 * p + 1
        self.build_tree(left, l, mid)
        self.build_tree(right, mid + 1, r)
        tree[p] = tree[left] + tree[right]

    def add_point(self, p, l, r, index, add):
        """Add ``add`` at ``index``; positions outside [l, r] are ignored.

        Every node on the root-to-leaf path contains ``index``, so each one
        is incremented on the way down.
        """
        if not l <= index <= r:
            return
        tree = self.interval_tree
        while True:
            tree[p] += add
            if l == r:
                return
            mid = (l + r) // 2
            if index <= mid:
                p, r = 2 * p, mid
            else:
                p, l = 2 * p + 1, mid + 1

    def sum_interval(self, p, l, r, x, y):
        """Return the sum over positions [x, y] within node p's range [l, r]."""
        if r < x or l > y:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        mid = (l + r) // 2
        total = 0
        if x <= mid:
            total += self.sum_interval(2 * p, l, mid, x, y)
        if y > mid:
            total += self.sum_interval(2 * p + 1, mid + 1, r, x, y)
        return total

class NumArray:
    """LeetCode 307: array with point assignment and inclusive range sums."""

    def __init__(self, nums: List[int]):
        self.size = len(nums)
        # Private copy: update() writes into self.nums, and aliasing the
        # caller's list would silently mutate it.
        self.nums = list(nums)
        self.interval_tree = IntervalTree(self.size, self.nums)

    def update(self, index: int, val: int) -> None:
        """Set the element at 0-based ``index`` to ``val``."""
        # The tree only supports point-add, so apply the difference.
        add = val - self.nums[index]
        self.nums[index] = val
        self.interval_tree.add_point(1, 1, self.size, index + 1, add)

    def sumRange(self, left: int, right: int) -> int:
        """Return the sum of elements ``left..right`` (0-based, inclusive)."""
        return self.interval_tree.sum_interval(1, 1, self.size, left + 1, right + 1)

4.单点更新,区间求和

链接: 327. 区间和的个数

这题稍麻烦一点:需要统计某个值域区间内已插入数字的个数。把每个值的计数初始化为 0,插入时对应位置 +1,计数问题就转化为区间求和。
参考链接: [LeetCode解题报告]327. 区间和的个数

class IntervalTree:
    """Counting segment tree over 1-based positions [1, size].

    Each leaf holds how many times its position has been inserted; an
    internal node holds the total count of its range.
    """

    def __init__(self, size):
        self.size = size
        self.interval_tree = [0] * (size * 4)

    def insert(self, p, l, r, index):
        """Record one more occurrence of ``index`` (ignored if outside [l, r])."""
        if not (l <= index <= r):
            return
        tree = self.interval_tree
        if l != r:
            mid = (l + r) // 2
            if index > mid:
                self.insert(2 * p + 1, mid + 1, r, index)
            else:
                self.insert(2 * p, l, mid, index)
            tree[p] = tree[2 * p] + tree[2 * p + 1]
        else:
            tree[p] += 1

    def query(self, p, l, r, x, y):
        """Return how many inserted values fall in positions [x, y]."""
        if r < x or l > y:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        mid = (l + r) // 2
        count = 0
        if x <= mid:
            count += self.query(2 * p, l, mid, x, y)
        if y > mid:
            count += self.query(2 * p + 1, mid + 1, r, x, y)
        return count

class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count subarrays whose sum lies in [lower, upper].

        With prefix sums ``P``, a subarray (i, j] qualifies iff
        ``P[j] - upper <= P[i] <= P[j] - lower``.  Scan j left to right while
        a counting segment tree holds every earlier ``P[i]``.
        """
        prefix = list(accumulate(nums, initial=0))
        # Every value that is inserted or used as a query bound must have a
        # compressed position.
        coords = sorted({v for p in prefix for v in (p, p - lower, p - upper)})
        m = len(coords)
        seg = IntervalTree(m)
        total = 0
        for p in prefix:
            lo = bisect_left(coords, p - upper) + 1
            hi = bisect_left(coords, p - lower) + 1
            total += seg.query(1, 1, m, lo, hi)
            seg.insert(1, 1, m, bisect_left(coords, p) + 1)
        return total

5.区间更新,区间查询,无法离散化,动态开点。

链接: 732. 我的日程安排表 III

这题只能在线处理(每次预订都要立即返回结果),无法提前离散化,而坐标范围又大到 $10^9$。因此用字典按节点编号存储线段树,只有被访问到的节点才会创建,即"动态开点"。
参考链接: [LeetCode解题报告] 732. 我的日程安排表 III

class IntervalTree:
    """Dynamically allocated segment tree: range add, range maximum.

    Nodes live in dicts keyed by heap index (children of ``p`` are ``2p``
    and ``2p+1``), so only visited nodes take memory and the tree can span
    a huge coordinate range without discretization.  ``interval_tree[p]``
    is the maximum inside node p's range (including p's own pending add);
    ``lazys[p]`` is an add still owed to p's children.
    """

    def __init__(self):
        self.interval_tree = collections.defaultdict(int)
        self.lazys = collections.defaultdict(int)

    def give_lay_to_son(self, p, l, r):
        """Push node p's pending add down to both children.

        ``l`` and ``r`` are not needed for an add; kept for signature parity.
        """
        lazys = self.lazys
        # .get avoids materializing an empty entry for every visited node.
        pending = lazys.get(p, 0)
        if pending == 0:
            return
        interval_tree = self.interval_tree
        interval_tree[p*2] += pending
        interval_tree[p*2+1] += pending
        lazys[p*2] += pending
        lazys[p*2+1] += pending
        lazys[p] = 0

    def add(self, p, l, r, x, y, val):
        """Add ``val`` to every position in [x, y]."""
        # Prune nodes disjoint from [x, y] (the author reports TLE without
        # this check).
        if r < x or y < l:
            return
        interval_tree = self.interval_tree
        if x <= l and r <= y:
            interval_tree[p] += val
            self.lazys[p] += val
            return
        # MyCalendarThree only reads the root, so this push-down could be
        # skipped there; it is kept so query() on sub-ranges stays correct.
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        if x <= mid:
            self.add(p*2, l, mid, x, y, val)
        if mid < y:
            self.add(p*2+1, mid+1, r, x, y, val)
        interval_tree[p] = max(interval_tree.get(p*2, 0), interval_tree.get(p*2+1, 0))

    def query(self, p, l, r, x, y):
        """Return the maximum over positions [x, y]; 0 if nothing overlaps.

        0 is the starting value, so results are never below 0.
        """
        # Without this guard a query outside [l, r] recursed forever.
        if y < l or r < x:
            return 0
        if x <= l and r <= y:
            return self.interval_tree.get(p, 0)
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        s = 0
        if x <= mid:
            s = max(s, self.query(p*2, l, mid, x, y))
        if mid < y:
            s = max(s, self.query(p*2+1, mid+1, r, x, y))
        return s


class MyCalendarThree:
    """LeetCode 732: after each booking, report the maximum overlap count."""

    def __init__(self):
        self.tree = IntervalTree()

    def book(self, start: int, end: int) -> int:
        """Book the half-open interval [start, end); return the max overlap."""
        # A booking covers integer time points start .. end-1.  The tree
        # spans [0, 10**9] so that start == 0 is inside it; the previous
        # range [1, 10**9+1] silently dropped time point 0 (e.g. book(0, 1)).
        # Assumes LeetCode 732 bounds 0 <= start < end <= 10**9 — confirm for
        # other inputs.
        self.tree.add(1, 0, 10**9, start, end - 1, 1)
        return self.tree.interval_tree[1]

6.单点更新,区间查询最大值。

链接: 6206. 最长递增子序列 II

这题为 LIS 提供了一个新的优化思路:传统 DP 对每个位置都要 $O(n)$ 扫描之前的所有状态,总复杂度为 $O(n^2)$;用线段树按值域维护"以某个值结尾的最长长度",可以把单次查询降到 $O(\log m)$($m$ 为值域大小)。
参考链接: [LeetCode周赛复盘] 第 310 场周赛20220911

class IntervalTree:
    """Point "raise to at least val" / range-max segment tree over [1, size].

    Node ``p`` has children ``2p`` and ``2p+1``; ``interval_tree[p]`` is the
    maximum of node p's range.  ``nums`` is stored but not read here.
    """

    def __init__(self, size, nums=None):
        self.size = size
        self.nums = nums
        self.interval_tree = [0] * (4 * size)

    def update_point(self, p, l, r, index, val):
        """Raise position ``index`` to at least ``val`` (ignored if outside [l, r]).

        Every node on the root-to-leaf path contains ``index``, so each is
        raised on the way down.
        """
        if not l <= index <= r:
            return
        tree = self.interval_tree
        while True:
            tree[p] = max(tree[p], val)
            if l == r:
                return
            mid = (l + r) // 2
            if index <= mid:
                p, r = 2 * p, mid
            else:
                p, l = 2 * p + 1, mid + 1

    def query(self, p, l, r, x, y):
        """Return the maximum over positions [x, y]; 0 if nothing overlaps."""
        if r < x or l > y:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        mid = (l + r) // 2
        best = 0
        if x <= mid:
            best = max(best, self.query(2 * p, l, mid, x, y))
        if y > mid:
            best = max(best, self.query(2 * p + 1, mid + 1, r, x, y))
        return best

    
class Solution:
    def lengthOfLIS(self, nums: List[int], k: int) -> int:
        """Longest strictly increasing subsequence with adjacent gaps <= k.

        The tree is indexed by value: position v stores the best length of
        a valid subsequence ending in value v.  A new value v may extend any
        subsequence ending in [v-k, v-1].
        """
        top = max(nums)
        tree = IntervalTree(top)
        best = 0
        for v in nums:
            # Bounds are clamped at 0, which lies below the tree's range
            # [1, top] and therefore contributes nothing.
            prev_best = tree.query(1, 1, top, max(0, v - k), max(0, v - 1))
            cur = prev_best + 1
            tree.update_point(1, 1, top, v, cur)
            best = max(best, cur)
        return best

7. 区间01翻转(异或),区间查询1的个数。

链接: 6358. 更新数组后处理求和查询
链接: P3870 [TJOI2009] 开关

  • 注意处理时,lazy为1才需要向下处理。
  • 翻转时令 lazy ^= 1,并把区间内 1 的个数更新为:区间长度 − 原来 1 的个数。
class IntervalTree:
    """Segment tree over 0/1 values: range flip (XOR 1), range count of 1s.

    Positions are 1-based ``[1, size]``; node ``p`` has children ``2p`` and
    ``2p+1``.  ``interval_tree[p]`` is the number of 1s in node p's range;
    ``lazys[p] == 1`` means p's children still owe one pending flip.
    """

    def __init__(self, size):
        self.size = size
        self.interval_tree = [0 for _ in range(size*4)]
        self.lazys = [0 for _ in range(size*4)]

    def give_lay_to_son(self, p, l, r):
        """Apply node p's pending flip (if any) to both children."""
        interval_tree = self.interval_tree
        lazys = self.lazys
        if lazys[p] == 0:
            return
        mid = (l + r) // 2
        # Flipping a range turns its count of 1s into (length - count).
        interval_tree[p*2] = mid - l + 1 - interval_tree[p*2]
        interval_tree[p*2+1] = r - mid - interval_tree[p*2+1]
        # Two flips cancel out, so the children's tags toggle, not set.
        lazys[p*2] ^= 1
        lazys[p*2+1] ^= 1
        lazys[p] = 0

    def update(self, p, l, r, x, y, val):
        """Flip every bit in positions [x, y] (0 -> 1, 1 -> 0).

        ``val`` is accepted for signature compatibility with the other
        templates but is ignored: the operation is always a flip.
        """
        if y < l or r < x:
            return
        interval_tree = self.interval_tree
        lazys = self.lazys
        if x <= l and r <= y:
            interval_tree[p] = r - l + 1 - interval_tree[p]
            lazys[p] ^= 1
            return
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        if x <= mid:
            self.update(p*2, l, mid, x, y, val)
        if mid < y:
            self.update(p*2+1, mid+1, r, x, y, val)
        interval_tree[p] = interval_tree[p*2] + interval_tree[p*2+1]

    def query(self, p, l, r, x, y):
        """Return the number of 1s in positions [x, y]."""
        if y < l or r < x:
            return 0
        if x <= l and r <= y:
            return self.interval_tree[p]
        self.give_lay_to_son(p, l, r)
        mid = (l + r) // 2
        s = 0
        if x <= mid:
            s += self.query(p*2, l, mid, x, y)
        if mid < y:
            s += self.query(p*2+1, mid+1, r, x, y)
        return s
    
class Solution:
    def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
        """Process flip / add / report queries (LeetCode 2569).

        op 1 flips bits of nums1 in [a, b]; op 2 adds ``a * popcount(nums1)``
        to the running sum of nums2; op 3 records the current sum.
        """
        n = len(nums1)
        total = sum(nums2)
        tree = IntervalTree(n)
        # Load nums1: flipping a single 0 bit turns it into 1.
        for pos, bit in enumerate(nums1, start=1):
            if bit:
                tree.update(1, 1, n, pos, pos, 1)
        answers = []
        for op, a, b in queries:
            if op == 1:
                tree.update(1, 1, n, a + 1, b + 1, 1)
            elif op == 2:
                total += a * tree.query(1, 1, n, 1, n)
            else:
                answers.append(total)
        return answers

三、其他

  1. 如果还是卡常数,有些区间问题可以转化为树状数组,常数小,代码短,不过真的很难理解,还是线段树好写。遇到就套板吧。

四、更多例题

  • 729. 我的日程安排表 I 求区间极值
  • 731. 我的日程安排表 II 求区间极值

你可能感兴趣的:(python刷题模板,python,算法,数据结构)