线段树算法是一种快速查询一段区间内的信息的算法,由于其实现简单,所以广泛应用于程序设计竞赛中。
线段树是一棵完美二叉树,即所有的叶子节点的深度均相同, 并且所有的非叶子节点都有两个子节点。每个节点维护一个区间,这个区间为父节点二分后的子区间,根节点维护整个区间,叶子节点维护单个元素,当元素个数为n
时,对区间的操作都可以在O(log n)
的时间内完成,因为此时树的深度为log2 n + 1
,每次操作只需从叶子节点开始,往上更新至根节点,每层只需更新相关的一个区间即可,操作次数log2 n + 1
,即在O(log n)
的时间内可完成。
线段树可以提供不同的功能,例如最常见的求区间内的最大最小值和求区间内的和,还有其他类似的功能,实现思路基本相同。
O(logn)
查询的区间 和 线段树节点区间 相等 -> 直接返回
查询的区间 被 线段树节点区间 包含-> 递归向下搜索左右子树
查询的区间 和 线段树节点区间 不相交 -> 结束
查询的区间 和 线段树节点区间 相交且不相等 -> 分裂查询区间
O(n)
自上而下递归分裂
自下而上回溯更新
O(logn)
自上而下递归查询
自下而上回溯更新
SegmentTree
的 Python
实现:
# -*- coding: utf-8 -*-
"""
Description: 线段树
构建一个线段树,线段树的非叶子节点存储的是区间内的最大值
"""
import math
class SegmentTree(object):
def __init__(self, arr):
self.size = len(arr)
# approximate the overall size of segment tree with array
self.st = [0] * (4 * self.size)
# build segment tree
self.build(0, 0, self.size - 1, arr)
def left(self, idx):
"""求索引为 idx 的节点的左子树的索引"""
return idx * 2 + 1
def right(self, idx):
"""求索引为 idx 的节点的右子树的索引"""
return idx * 2 + 2
def build(self, idx, left_idx, right_idx, arr):
if left_idx == right_idx:
self.st[idx] = arr[left_idx]
else:
mid = (left_idx + right_idx) // 2
self.build(self.left(idx), left_idx, mid, arr)
self.build(self.right(idx), mid + 1, right_idx, arr)
# 非叶子节点中存储的是区间内的最大值
self.st[idx] = max(self.st[self.left(idx)], self.st[self.right(idx)])
def update(self, a, b, value):
"""将[a, b] 区间被的值更新为 value"""
return self.update_recursive(0, 0, self.size - 1, a, b, value)
def update_recursive(self, idx, left_idx, right_idx, a, b, value):
if right_idx < a or left_idx > b:
return True
if left_idx == right_idx:
self.st[idx] = value
return True
mid = (left_idx + right_idx) // 2
self.update_recursive(self.left(idx), left_idx, mid, a, b, value)
self.update_recursive(self.right(idx), mid + 1, right_idx, a, b, value)
self.st[idx] = max(self.st[self.left(idx)], self.st[self.right(idx)])
return True
def query(self, a, b):
return self.query_recursive(0, 0, self.size - 1, a, b)
def query_recursive(self, idx, left_idx, right_idx, a, b):
if right_idx < a or left_idx > b:
return -math.inf
if left_idx >= a and right_idx <= b:
return self.st[idx]
mid = (left_idx + right_idx) // 2
q1 = self.query_recursive(self.left(idx), left_idx, mid, a, b)
q2 = self.query_recursive(self.right(idx), mid + 1, right_idx, a, b)
return max(q1, q2)
def show_data(self):
show_list = []
for i in range(0, self.size):
show_list += [self.query(i, i)]
print(show_list)
if __name__ == '__main__':
array = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9]
print(len(array))
segment_tree = SegmentTree(array)
print(segment_tree.query(4, 6))
print(segment_tree.query(7, 14))
print(segment_tree.query(-3, 2))
segment_tree.show_data()
segment_tree.update(1, 3, 111)
segment_tree.show_data()
segment_tree.update(7, 8, 234)
segment_tree.show_data()