“线段树”实现了高效的“数组区间查询”与“数组区间更新”
“线段树”(segment tree)又称“区间树”,是一个高级数据结构,应用的对象是“数组”。
先来看 LeetCode 第 303 题和 LeetCode 第 307 题。
LeetCode 第 303 题:区域和检索 - 数组不可变
传送门:303. 区域和检索 - 数组不可变
给定一个整数数组 nums,求出数组从索引 i 到 j (i ≤ j) 范围内元素的总和,包含 i, j 两点。
示例:
给定 nums = [-2, 0, 3, -5, 2, -1],求和函数为 sumRange() sumRange(0, 2) -> 1 sumRange(2, 5) -> -1 sumRange(0, 5) -> -3
说明:
- 你可以假设数组不可变。
- 会多次调用 sumRange 方法。
思路:我们可以设计一个前缀和数组 cumsum ,在查询的时候,只用 时间复杂度,不过在数组元素有频繁更新的时候,会导致性能下降,即这种方式不适用于 LeetCode 第 307 题。
Python 代码:
class NumArray:
def __init__(self, nums):
"""
:type nums: List[int]
"""
self.size = len(nums)
if self.size > 0:
self.cumsum = [0 for _ in range(self.size + 1)]
self.cumsum[1] = nums[0]
for i in range(2, len(nums) + 1):
self.cumsum[i] = self.cumsum[i - 1] + nums[i - 1]
def sumRange(self, i, j):
"""
:type i: int
:type j: int
:rtype: int
"""
if self.size > 0:
return self.cumsum[j + 1] - self.cumsum[i]
return 0
# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# param_1 = obj.sumRange(i,j)
Java 代码:
public class NumArray {
// cumsum 实现
// [1,2,3,4]
// [1,3,6,10]
private int[] nums;
public NumArray(int[] nums) {
this.nums = nums;
for (int i = 1; i < nums.length; i++) {
nums[i] = nums[i] + nums[i - 1];
}
}
public int sumRange(int i, int j) {
return nums[j] - (i - 1 < 0 ? 0 : nums[i - 1]);
}
public static void main(String[] args) {
int[] nums = {1, 2, 3, 4};
NumArray numArray = new NumArray(nums);
int result = numArray.sumRange(2, 3);
System.out.println(result);
}
}
LeetCode 第 307 题:区域和检索 - 数组可修改
传送门:307. 区域和检索 - 数组可修改。
给定一个整数数组 nums,求出数组从索引 i 到 j (i ≤ j) 范围内元素的总和,包含 i, j 两点。
update(i, val) 函数可以通过将下标为 i 的数值更新为 val,从而对数列进行修改。
示例:
Given nums = [1, 3, 5] sumRange(0, 2) -> 9 update(1, 2) sumRange(0, 2) -> 8
说明:
- 数组仅可以在 update 函数下进行修改。
- 你可以假设 update 函数与 sumRange 函数的调用次数是均匀分布的。
如果我们不使用任何数据结构,每次求“区间和”,都会遍历这个区间里的所有元素。如果区间里包含的元素很多,并且查询次数很频繁,时间复杂度就接近 。如果我们使用线段树,就可以把时间复杂度降低到 。
这里要注意的是“线段树”解决的区间问题不涉及“添加”与“删除”操作,即“CURD”,我们只负责“U” 和 “R”。
使用“遍历”与使用“线段树”对于“区间更新”与“区间查询”操作的复杂度
遍历 | 线段树 | |
---|---|---|
区间查询 | ||
区间更新 |
说明:由于我们的线段树(区间树)采用平衡二叉树实现,因此 中的对数函数以 为底,即 。
“线段树”可以使用数组表示
以前我们学习过“堆”,并且知道“堆”是一棵“完全二叉树”,因此“堆”可以用数组表示,基于此,我们很自然地想到可以用数组表示“线段树”。
完全二叉树:除了最后一层以外,其余各层的结点数达到最大,并且最后一层所有的结点都连续地、集中地存储在最左边。
线段树虽然不是完全二叉树,但线段树是平衡二叉树,依然也可以用数组表示。
“自顶向下”递归构建线段树
首先看看“线段树”长什么样。
线段树是一种二叉树结构,不过在实现的时候,可以使用数组实现,这一点和优先队列是一致的。
需要多少空间
“线段树”的一个经典实现是从上到下递归构建,这一点很像根据员工人数来定领导的人数,设置多少领导的个数就要看员工有多少人了。再想一想,我们在开篇对于线段树的介绍,线段树适合支持的操作是“查询”和“更新”,不适用于“添加”和“删除”。
下面以“员工和领导”为例,讲解从上到下逐步构建线段树的步骤:我们首先要解决的问题是“一共要设置多少领导”,我们宁可有一些位置没有人坐,也要让所有的人都坐下,因此我们在做估计的时候只会放大。
比较极端的一种情况:
我们假设员工的人数为 ,我们也可以认为这就是是我们问题的规模,如果 可以表示成 (例如,、、), 是正整数,这种情况下,组织出来的数一定是满二叉树(除叶子结点外的所有结点均有两个子结点)。那么要设置的领导的人数就是 ,于是我们设置 长度的数组就一定可以容纳下这么多领导和员工。
下面考虑一种糟糕的情况,例如我们的员工人数刚刚好是 次方幂多 ,例如 、、,我们的思路很简单,看看可不可以转化成上面那种情况,因为满二叉树一定是完全二叉树,我们就可以使用数组来表示),原则仍然是放大,例如: “ 放大到 ”,“ 放大到 ”, 但是我们不这么做,我们做得再“过分”一点,我们放大到 倍,它一定比大于问题规模 的最小 次方幂还大,此时为了组织成完全二叉树,将问题规模放大到 ,由上面的分析,我们知道还要给领导准备 把椅子,那么总共领导和员工就要准备 把椅子。
线段树是一颗平衡二叉树
线段树是一棵平衡二叉树(最大深度和最小深度的差距最多为 )。平衡二叉树不会像二分搜索树那样变成一个链表,并且平衡二叉树也可以用数组来表示。
我们还要清楚一点,我们上面只是为了分析出,我们要处理问题规模为 的问题的时候,要准备多少空间,我们分析出当员工数为 的时候,最多分配到 把椅子就能把领导和员工都装下了。下面展示一些图来表示这些情况,特别注意,我们分析的时候是从下到上的,但是实际上,我们拿到问题规模以后的划分却是从上到下的。我们的确浪费了一些空间,甚至有的时候我们浪费了很多空间。
根据上面的讨论,我们可以写出线段树的框架:
Python 代码:
class SegmentTree:
def __init__(self, arr):
self.data = arr
# 开 4 倍大小的空间
self.tree = [None for _ in range(4 * len(arr))]
def get_size(self):
return len(self.data)
def get(self, index):
if index < 0 or index >= len(self.data):
raise Exception("Index is illegal.")
return self.data[index]
def __left_child(self, index):
return 2 * index + 1
def __right_child(self, index):
return 2 * index + 2
Java 代码:
public class SegmentTree {
// 一共要给领导和员工准备的椅子,是我们要构建的辅助数据结构
private E[] tree;
// 原始的领导和员工数据,这是一个副本
private E[] data;
public SegmentTree(E[] arr) {
this.data = data;
// 数组初始化
data = (E[]) new Object[arr.length];
for (int i = 0; i < arr.length; i++) {
data[i] = arr[i];
}
tree = (E[]) new Object[4 * arr.length];
}
public int getSize() {
return data.length;
}
public E get(int index) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal.");
}
return data[index];
}
/**
* 返回完全二叉树的数组表示中,索引所表示的元素的左孩子结点的索引
* 注意:索引编号从 0 开始
*
* @param 线段树的某个结点的索引
* @return 传入的结点的左结点的索引
*/
public int leftChild(int index) {
return 2 * index + 1;
}
/**
* 返回完全二叉树的数组表示中,索引所表示的元素的左孩子结点的索引
* 注意:索引编号从 0 开始
*
* @param 线段树的某个结点的索引
* @return 传入的结点的右结点的索引
*/
public int rightChild(int index) {
return 2 * index + 2;
}
}
根据原始数组创建线段树
这一节的目标是:我们把员工的信息输入一棵线段树,让这棵线段树组织出领导架构。即已知 data 数组,要把 tree 数组构建出来。
分析递归结构
重点体会:二叉树每做一次分支都是“平均地”一分为二进行的。
递归到底的时候,这个区间只有 个元素。
设计私有函数,我们需要考虑 个变量:
1、我们要创建的线段树的根结点的索引,这个索引是线段树的索引;
2、对于线段树结点所要表示的 data 数组的区间的左端点是什么;
3、对于线段树结点所要表示的 data 数组的区间的右端点是什么。
Java 代码:
buildSegmentTree(0, 0, arr.length - 1);
Java 代码:关键代码
/**
* 这个递归方法的描述一定要非常清楚:
* 画出 tree 树中以 treeIndex 为根的,统计 data 数组中 [l,r] 区间中的元素
* 这个方法的实现引入了一个 merge 接口,使得外部可以传入一个方法,方法是如何实现的是根据业务而定
* 核心代码只有几行,这里关键还是在于递归方法
*
* @param treeIndex 我们要创建的线段树根结点所在的索引,treeIndex 是 tree 的索引
* @param l 对于 treeIndex 结点所要表示的 data 区间端点是什么,l 是 data 的索引
* @param r 对于 treeIndex 结点所要表示的 data 区间端点是什么,r 是 data 的索引
*/
private void buildSegmentTree(int treeIndex, int l, int r) {
// 考虑递归到底的情况
if (l == r) {
// 平衡二叉树叶子结点的赋值就是靠这句话形成的
tree[treeIndex] = data[l]; // data[r],此时对应叶子结点的情况
return;// return 不能忘记
}
int mid = l + (r - l) / 2;
int leftChild = leftChild(treeIndex);
int rightChild = rightChild(treeIndex);
// 假设左边右边都处理完了以后,再处理自己
// 这一点基于,高层信息的构建依赖底层信息的构建
// 这个递归的过程我们可以通过画图来理解
// 仔细阅读下面的这三行代码,是不是像极了二分搜索树的后序遍历,我们先处理了左右孩子结点,最后处理自己
buildSegmentTree(leftChild, l, mid);
buildSegmentTree(rightChild, mid + 1, r);
// 注意:merge 的实现根据业务而定
tree[treeIndex] = merge.merge(tree[leftChild], tree[rightChild]);
}
Merge 接口的设计,这里使用传入对象的方式实现了方法传递,是 Command 设计模式。
Java 代码:
public interface Merge {
E merge(E e1, E e2);
}
给 SegmentTree
覆盖 toString
方法,用于打印线段树表示的数组,以便执行测试用例。
@Override
public String toString() {
StringBuilder s = new StringBuilder();
s.append("[");
for (int i = 0; i < tree.length; i++) {
if(tree[i] == null){
s.append("NULL");
}else{
s.append(tree[i]);
}
s.append(",");
}
s.append("]");
return s.toString();
}
4、测试方法
public class Main {
public static void main(String[] args) {
Integer[] nums = {0, -1, 2, 4, 2};
SegmentTree segmentTree = new SegmentTree(nums, new Merge() {
@Override
public Integer merge(Integer e1, Integer e2) {
return e1 + e2;
}
});
System.out.println(segmentTree);
}
}
区间查询
通过编写二分搜索树的经验,我们知道,一些递归的写法通常要写一个辅助函数,在这个辅助函数里完成递归调用。那么对于这个问题中,辅助函数的设计就显得很关键了。
// 在一棵子树里做区间查询,dataL 和 dataR 都是原始数组的索引
public E query(int dataL, int dataR) {
if (dataL < 0 || dataL >= data.length || dataR < 0 || dataR >= data.length || dataL > dataR) {
throw new IllegalArgumentException("Index is illegal.");
}
// data.length - 1 边界不能弄错
return query(0, 0, data.length - 1, dataL, dataR);
}
在这个辅助函数的实现过程中,可以画一张图来展现一下具体的计算过程。
体会下面这个过程:
我们总是自上而下,从根结点开始向下查询,最坏情况下,才会查询到叶子结点。
Java 代码:
// 这是一个递归调用的辅助方法,应该定义成私有方法
private E query(int treeIndex, int l, int r, int dataL, int dataR) {
if (l == dataL && r == dataR) {
// 这里一定不要犯晕,看图说话
return tree[treeIndex];
}
int mid = l + (r - l) / 2;
int leftChildIndex = leftChild(treeIndex);
int rightChildIndex = rightChild(treeIndex);
// 画个示意图就能清楚自己的逻辑是怎样的
if (dataR <= mid) {
return query(leftChildIndex, l, mid, dataL, dataR);
}
if (dataL >= mid + 1) {
return query(rightChildIndex, mid + 1, r, dataL, dataR);
}
// 横跨两边的时候,先算算左边,再算算右边
E leftResult = query(leftChildIndex, l, mid, dataL, mid);
E rightResult = query(rightChildIndex, mid + 1, r, mid + 1, dataR);
return merge.merge(leftResult, rightResult);
}
LeetCode 第 303 题:区域和检索 - 数组不可变
传送门:303. 区域和检索 - 数组不可变
给定一个整数数组 nums,求出数组从索引 i 到 j (i ≤ j) 范围内元素的总和,包含 i, j 两点。
示例:
给定 nums = [-2, 0, 3, -5, 2, -1],求和函数为 sumRange() sumRange(0, 2) -> 1 sumRange(2, 5) -> -1 sumRange(0, 5) -> -3
说明:
- 你可以假设数组不可变。
- 会多次调用 sumRange 方法。
思路2:基于线段树(区间树)的实现。
Python 代码:
class NumArray:
class SegmentTree:
def __init__(self, arr, merge):
self.data = arr
# 开 4 倍大小的空间
self.tree = [None for _ in range(4 * len(arr))]
if not hasattr(merge, '__call__'):
raise Exception('不是函数对象')
self.merge = merge
self.__build_segment_tree(0, 0, len(self.data) - 1)
def get_size(self):
return len(self.data)
def get(self, index):
if index < 0 or index >= len(self.data):
raise Exception("Index is illegal.")
return self.data[index]
def __left_child(self, index):
return 2 * index + 1
def __right_child(self, index):
return 2 * index + 2
def __build_segment_tree(self, tree_index, data_l, data_r):
# 区间只有 1 个数的时候,线段树的值,就是数组的值,不必做融合
if data_l == data_r:
self.tree[tree_index] = self.data[data_l]
# 不要忘记 return
return
# 然后一分为二去构建
mid = data_l + (data_r - data_l) // 2
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
self.__build_segment_tree(left_child, data_l, mid)
self.__build_segment_tree(right_child, mid + 1, data_r)
# 左右都构建好以后,再构建自己,因此是后续遍历
self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])
def __str__(self):
# 打印线段树
return str([str(ele) for ele in self.tree])
def query(self, data_l, data_r):
if data_l < 0 or data_l >= len(self.data) or data_r < 0 or data_r >= len(self.data) or data_l > data_r:
raise Exception('Index is illegal.')
return self.__query(0, 0, len(self.data) - 1, data_l, data_r)
def __query(self, tree_index, tree_l, tree_r, data_l, data_r):
# 一般而言,线段树区间肯定会大一些,所以会递归查询下去
# 如果要查询的线段树区间和数据区间完全吻合,把当前线段树索引的返回回去就可以了
if tree_l == data_l and tree_r == data_r:
return self.tree[tree_index]
mid = tree_l + (tree_r - tree_l) // 2
# 线段树的左右两个索引
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
# 因为构建时是这样
# self.__build_segment_tree(left_child, data_l, mid)
# 所以,如果右边区间不大于中间索引,就只须要在左子树查询就可以了
if data_r <= mid:
return self.__query(left_child, tree_l, mid, data_l, data_r)
# 同理,如果左边区间 >= mid + 1,就只用在右边区间找就好了
# self.__build_segment_tree(right_child, mid + 1, data_r)
if data_l >= mid + 1:
return self.__query(right_child, mid + 1, tree_r, data_l, data_r)
# 横跨两边的时候,先算算左边,再算算右边
left_res = self.__query(left_child, tree_l, mid, data_l, mid)
right_res = self.__query(right_child, mid + 1, tree_r, mid + 1, data_r)
return self.merge(left_res, right_res)
def __init__(self, nums):
"""
:type nums: List[int]
"""
if len(nums) > 0:
self.st = NumArray.SegmentTree(nums, lambda a, b: a + b)
def sumRange(self, i, j):
"""
:type i: int
:type j: int
:rtype: int
"""
if self.st is None:
return 0
return self.st.query(i, j)
# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# param_1 = obj.sumRange(i,j)
Java 代码:可以点击 这里 查看。
区间更新
想一想更新的步骤,根据画图分析。从树的根开始更新,先把数据更新了,再更新 tree。set
方法 的设计与实现,其实是程式化的,这个过程熟悉了以后写起来,就会比较自然。最后不要忘记 merge 一下,从叶子结点开始,父辈结点,祖辈结点,直到根结点都要更新。
Java 代码:
public void set(int dataIndex, E val) {
if (dataIndex < 0 || dataIndex >= data.length) {
throw new IllegalArgumentException("Index is illegal.");
}
data[dataIndex] = val;
set(0, 0, data.length - 1, dataIndex, val);
}
Java 代码:
private void set(int treeIndex, int l, int r, int dataIndex, E val) {
if (l == r) {
// 来到平衡二叉树的叶子点,这一步是最底层的更新操作
tree[treeIndex] = val;
return;
}
// 更新祖辈结点,还是先更新左边孩子和右边孩子,再更新
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (r - l) / 2;
if (dataIndex >= mid + 1) {
// 到右边更新
set(rightTreeIndex, mid + 1, r, dataIndex, val);
}
if (dataIndex <= mid) {
// 到左边更新
set(leftTreeIndex, l, mid, dataIndex, val);
}
tree[treeIndex] = merge.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
LeetCode 上第 307 号问题:区域和检索 - 数组可修改
传送门:307. 区域和检索 - 数组可修改。
给定一个整数数组 nums,求出数组从索引 i 到 j (i ≤ j) 范围内元素的总和,包含 i, j 两点。
update(i, val) 函数可以通过将下标为 i 的数值更新为 val,从而对数列进行修改。
示例:
Given nums = [1, 3, 5] sumRange(0, 2) -> 9 update(1, 2) sumRange(0, 2) -> 8
说明:
- 数组仅可以在 update 函数下进行修改。
- 你可以假设 update 函数与 sumRange 函数的调用次数是均匀分布的。
思路1:基于 cumsum 数组的写法,效率不高。
说明:这道题如果采用 cumsum 数组的实现,会得到一个 TLE 的结果。但是采用线段树的实现,就能很容易通过。多看几遍,就明白是怎么回事了。
Python 代码:
class NumArray:
class SegmentTree:
def __init__(self, arr, merge):
self.data = arr
# 开 4 倍大小的空间
self.tree = [None for _ in range(4 * len(arr))]
if not hasattr(merge, '__call__'):
raise Exception('不是函数对象')
self.merge = merge
self.__build_segment_tree(0, 0, len(self.data) - 1)
def get_size(self):
return len(self.data)
def get(self, index):
if index < 0 or index >= len(self.data):
raise Exception("Index is illegal.")
return self.data[index]
def __left_child(self, index):
return 2 * index + 1
def __right_child(self, index):
return 2 * index + 2
def __build_segment_tree(self, tree_index, data_l, data_r):
# 区间只有 1 个数的时候,线段树的值,就是数组的值,不必做融合
if data_l == data_r:
self.tree[tree_index] = self.data[data_l]
# 不要忘记 return
return
# 然后一分为二去构建
mid = data_l + (data_r - data_l) // 2
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
self.__build_segment_tree(left_child, data_l, mid)
self.__build_segment_tree(right_child, mid + 1, data_r)
# 左右都构建好以后,再构建自己,因此是后续遍历
self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])
def __str__(self):
# 打印线段树
return str([str(ele) for ele in self.tree])
def query(self, data_l, data_r):
if data_l < 0 or data_l >= len(self.data) or data_r < 0 or data_r >= len(self.data) or data_l > data_r:
raise Exception('Index is illegal.')
return self.__query(0, 0, len(self.data) - 1, data_l, data_r)
def __query(self, tree_index, tree_l, tree_r, data_l, data_r):
# 一般而言,线段树区间肯定会大一些,所以会递归查询下去
# 如果要查询的线段树区间和数据区间完全吻合,把当前线段树索引的返回回去就可以了
if tree_l == data_l and tree_r == data_r:
return self.tree[tree_index]
mid = tree_l + (tree_r - tree_l) // 2
# 线段树的左右两个索引
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
# 因为构建时是这样
# self.__build_segment_tree(left_child, data_l, mid)
# 所以,如果右边区间不大于中间索引,就只须要在左子树查询就可以了
if data_r <= mid:
return self.__query(left_child, tree_l, mid, data_l, data_r)
# 同理,如果左边区间 >= mid + 1,就只用在右边区间找就好了
# self.__build_segment_tree(right_child, mid + 1, data_r)
if data_l >= mid + 1:
return self.__query(right_child, mid + 1, tree_r, data_l, data_r)
# 横跨两边的时候,先算算左边,再算算右边
left_res = self.__query(left_child, tree_l, mid, data_l, mid)
right_res = self.__query(right_child, mid + 1, tree_r, mid + 1, data_r)
return self.merge(left_res, right_res)
def set(self, data_index, val):
if data_index < 0 or data_index >= len(self.data):
raise Exception('Index is illegal.')
# 先把数据更新了
self.data[data_index] = val
# 线段树的更新递归去完成
self.__set(0, 0, len(self.data) - 1, data_index, val)
def __set(self, tree_index, tree_l, tree_r, data_index, val):
if tree_l == tree_r:
# 注意:这里不能填 tree_l 或者 tree_r
self.tree[tree_index] = val
return
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
mid = tree_l + (tree_r - tree_l) // 2
if data_index >= mid + 1:
# 如果在右边,就只去右边更新
self.__set(right_child, mid + 1, tree_r, data_index, val)
if data_index <= mid:
# 如果在左边,就只去左边更新
self.__set(left_child, tree_l, mid, data_index, val)
# 左边右边都更新完以后,再更新自己
self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])
def __init__(self, nums):
"""
:type nums: List[int]
"""
self.size = len(nums)
if self.size:
self.st = NumArray.SegmentTree(nums, lambda a, b: a + b)
def update(self, i, val):
"""
:type i: int
:type val: int
:rtype: void
"""
if self.size:
self.st.set(i, val)
def sumRange(self, i, j):
"""
:type i: int
:type j: int
:rtype: int
"""
if self.size:
return self.st.query(i, j)
# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(i,val)
# param_2 = obj.sumRange(i,j)
Java 代码:可以点击 这里 查看。
“自底向上”的线段树实现
自底向上的线段树实现只要使用 倍原始数组大小的辅助空间就。下面的 2 张图就展示了这个过程:
我们根据结点个数的奇偶性,分别讨论,但是,最终我们发现,奇偶性并不影响结论。
我们从下到上构建二叉树:
1、先把原始结点做一个偏移,偏移量是原始数组的长度;
2、“自底向上”构建,即父节点就是该结点的索引值除以 ,这个除法是整数除法;
我们发现,不论是奇数个结点还是偶数个结点最终都可以达到根结点,并且根结点的索引是 ,索引是 的位置我们不用。
规律(不论结点个数是奇数还是偶数都成立):父结点的索引如果是 i
,子结点的索引就是 2 * i
和 2 * i + 1
。
Python 代码:
class SegmentTree:
# 自底向上的线段树实现
def __init__(self, arr, merge):
self.data = arr
self.size = len(arr)
# 开 2 倍大小的空间
self.tree = [None for _ in range(2 * self.size)]
if not hasattr(merge, '__call__'):
raise Exception('不是函数对象')
self.merge = merge
# 原始数值赋值
for i in range(self.size, 2 * self.size):
self.tree[i] = self.data[i - self.size]
# 从后向前赋值
for i in range(self.size - 1, 0, -1):
self.tree[i] = self.merge(self.tree[2 * i], self.tree[2 * i + 1])
def get_size(self):
return len(self.data)
def query(self, l, r):
l += self.size
r += self.size
res = 0
while l <= r:
# 如果左端点是奇数
if l & 1 == 1:
if res == 0:
# 一开始要加上叶子结点
res = self.tree[l]
else:
res = self.merge(res, self.tree[l])
# 把左端点变成偶数
l += 1
if r & 1 == 0:
if res == 0:
# 一开始要加上叶子结点
res = self.tree[r]
else:
res = self.merge(res, self.tree[r])
# 把右端点变成奇数
r -= 1
# 往叶子结点上走,所以是除以 2
l //= 2
r //= 2
return res
def set(self, i, val):
i += self.size
self.tree[i] = val
while i > 0:
left = i
right = i
if i & 1 == 0:
right = i + 1
else:
left = i - 1
if left == 0:
self.tree[i // 2] = self.tree[right]
else:
self.tree[i // 2] = self.merge(self.tree[left], self.tree[right])
i //= 2
if __name__ == '__main__':
nums = [-2, 0, 3, -5, 2, -1]
st = SegmentTree(nums, lambda a, b: a + b)
result1 = st.query(0, 2)
print(result1)
result2 = st.query(2, 5)
print(result2)
result3 = st.query(0, 5)
print(result3)
Java 代码:
public class NumArray {
private SegmentTree segmentTree;
public NumArray(int[] nums) {
Merger merger = new Merger() {
@Override
public Integer merge(Integer e1, Integer e2) {
return e1 + e2;
}
};
Integer[] arr = new Integer[nums.length];
for (int i = 0; i < nums.length; i++) {
arr[i] = nums[i];
}
segmentTree = new SegmentTree(arr, merger);
}
public void update(int i, int val) {
segmentTree.set(i, val);
}
public int sumRange(int i, int j) {
return segmentTree.query(i, j);
}
private interface Merger {
E merge(E e1, E e2);
}
private class SegmentTree {
private E[] tree;
private int len;
private Merger merger;
private SegmentTree(E[] arr, Merger merger) {
this.merger = merger;
len = arr.length;
tree = (E[]) new Object[len * 2];
for (int i = len; i < 2 * len; i++) {
tree[i] = arr[i - len];
}
for (int i = len - 1; i > 0; i--) {
tree[i] = merger.merge(tree[2 * i], tree[2 * i + 1]);
}
}
public E query(int l, int r) {
l += len;
r += len;
E res = null;
while (l <= r) {
if (l % 2 == 1) {
if (res == null) {
res = tree[l];
} else {
res = merger.merge(res, tree[l]);
}
l++;
}
if (r % 2 == 0) {
if (res == null) {
res = tree[r];
} else {
res = merger.merge(res, tree[r]);
}
r--;
}
l /= 2;
r /= 2;
}
return res;
}
public void set(int i, E val) {
i += len;
tree[i] = val;
while (i > 0) {
int left = i;
int right = i;
// i 是左边结点
if (i % 2 == 0) {
right = i + 1;
} else {
left = i - 1;
}
if (left == 0) {
tree[i / 2] = tree[right];
} else {
tree[i / 2] = merger.merge(tree[left], tree[right]);
}
i /= 2;
}
}
}
}
本文源代码
Python:代码文件夹,Java:代码文件夹。
参考资料
1、B 站上一位 UP 主的讲解:线段树入门。
博客地址:https://wmathor.com
(本节完)