想要精通算法和SQL的成长之路 - 系列导航
我们先来看下线段树的含义:
Segment Tree
):是一种解决 区间查询问题 的数据结构。线段树通常用于解决以下类型的问题:
那么线段树的构建过程,通常是一个 分治递归 的过程。线段树的时间复杂度为O(logN)
。
例如我们有个数组:[1,2,3,4,5]
,它的区间和用线段树表示就是:
首先我们思考一下,从上图中,我们一个节点需要包括哪些数据:
那么不难得出,代码结构如下:
public class SegmentTreeNode {
public SegmentTreeNode left;
public SegmentTreeNode right;
public int start;
public int end;
public int val;
public SegmentTreeNode(int start, int end) {
this.start = start;
this.end = end;
}
}
我们往往针对的是一个数组,去构建它的线段树:
public class SegmentTree {
// 线段树的根节点
private SegmentTreeNode root;
public SegmentTree(int[] nums) {
root = buildTree(nums, 0, nums.length - 1);
}
private SegmentTreeNode buildTree(int[] nums, int start, int end) {
if (start > end) {
return null;
}
// 构建当前节点
SegmentTreeNode node = new SegmentTreeNode(start, end);
// 如果区间长度为1,那么当前节点的sum就是区间内的值
if (start == end) {
node.val = nums[start];
} else {
// 开始递归构建左右子树
int mid = (start + end) >> 1;
node.left = buildTree(nums, start, mid);
node.right = buildTree(nums, mid + 1, end);
// 当前节点的sum等于左右子树的sum之和
node.val = node.left.val + node.right.val;
}
return node;
}
}
构建完之后,我们需要计算区间和了:传入指定的区间queryStart
和queryEnd
,返回 [queryStart,queryEnd]
区间内的总和。
public int query(int queryStart, int queryEnd) {
return queryHelper(root, queryStart, queryEnd);
}
private int queryHelper(SegmentTreeNode node, int queryStart, int queryEnd) {
// 如果当前节点为空,或者当前节点的区间和不在查询区间内,那么返回0
if (node == null || queryStart > node.end || queryStart < node.start) {
return 0;
}
// 如果当前节点的区间完全在查询区间内,那么直接返回当前节点的sum。
// 例如我们要查询[2,4] 的区间和,[2, 4] = [2, 2] + [3, 4]. 当前节点的区间是 [2,2] 或者 [3,4],所以直接返回当前节点的sum即可。
if (node.start >= queryStart && node.end <= queryEnd) {
return node.val;
}
// 否则,我们需要递归查询左右子树
int mid = (node.start + node.end) >> 1;
// 注意这里的Math.min和Math.max,因为我们的查询区间是[queryStart, queryEnd],而当前节点的区间是[node.start, node.end],所以我们需要取交集
int leftSum = queryHelper(node.left, queryStart, Math.min(queryEnd, mid));
int rightSum = queryHelper(node.right, Math.max(queryStart, mid + 1), queryEnd);
return leftSum
}
上面的区间和计算,往往是基于数组的值不变的情况下进行的。那么假若数组中的某个元素被更新了,那么我们的区间和就不正确了,跟这个元素有关的所有链路,都要被更新,因此我们还需要准备一个更新的函数,它和查询非常类似,同样是递归操作。
public void update(int index, int newVal) {
updateHelper(root, index, newVal);
}
private void updateHelper(SegmentTreeNode node, int index, int newVal) {
if (node == null || index < node.start || index > node.end) {
return;
}
// 如果当前节点的区间就是要更新的区间,那么直接更新当前节点的sum即可
if (node.start == node.end) {
node.val = newVal;
return;
}
// 否则,我们需要递归更新左右子树
int mid = (node.start + node.end) >> 1;
if (index <= mid) {
updateHelper(node.left, index, newVal);
} else {
updateHelper(node.right, index, newVal);
}
// 更新完左右子树之后,需要更新当前节点的sum
node.val = node.left.val + node.right.val;
}
public class SegmentTree {
// 线段树的根节点
private SegmentTreeNode root;
public SegmentTree(int[] nums) {
root = buildTree(nums, 0, nums.length - 1);
}
private SegmentTreeNode buildTree(int[] nums, int start, int end) {
if (start > end) {
return null;
}
// 构建当前节点
SegmentTreeNode node = new SegmentTreeNode(start, end);
// 如果区间长度为1,那么当前节点的sum就是区间内的值
if (start == end) {
node.val = nums[start];
} else {
// 开始递归构建左右子树
int mid = (start + end) >> 1;
node.left = buildTree(nums, start, mid);
node.right = buildTree(nums, mid + 1, end);
// 当前节点的sum等于左右子树的sum之和
node.val = node.left.val + node.right.val;
}
return node;
}
public int query(int queryStart, int queryEnd) {
return queryHelper(root, queryStart, queryEnd);
}
private int queryHelper(SegmentTreeNode node, int queryStart, int queryEnd) {
// 如果当前节点为空,或者当前节点的区间和不在查询区间内,那么返回0
if (node == null || queryStart > node.end || queryStart < node.start) {
return 0;
}
// 如果当前节点的区间完全在查询区间内,那么直接返回当前节点的sum。
// 例如我们要查询[2,4] 的区间和,[2, 4] = [2, 2] + [3, 4]. 当前节点的区间是 [2,2] 或者 [3,4],所以直接返回当前节点的sum即可。
if (node.start >= queryStart && node.end <= queryEnd) {
return node.val;
}
// 否则,我们需要递归查询左右子树
int mid = (node.start + node.end) >> 1;
// 注意这里的Math.min和Math.max,因为我们的查询区间是[queryStart, queryEnd],而当前节点的区间是[node.start, node.end],所以我们需要取交集
int leftSum = queryHelper(node.left, queryStart, Math.min(queryEnd, mid));
int rightSum = queryHelper(node.right, Math.max(queryStart, mid + 1), queryEnd);
return leftSum + rightSum;
}
public void update(int index, int newVal) {
updateHelper(root, index, newVal);
}
private void updateHelper(SegmentTreeNode node, int index, int newVal) {
if (node == null || index < node.start || index > node.end) {
return;
}
// 如果当前节点的区间就是要更新的区间,那么直接更新当前节点的sum即可
if (node.start == node.end) {
node.val = newVal;
return;
}
// 否则,我们需要递归更新左右子树
int mid = (node.start + node.end) >> 1;
if (index <= mid) {
updateHelper(node.left, index, newVal);
} else {
updateHelper(node.right, index, newVal);
}
// 更新完左右子树之后,需要更新当前节点的sum
node.val = node.left.val + node.right.val;
}
}
第一节的代码,它有一个前提:
那如果我们不知道数组包含的元素个数以及各个元素值的时候,怎么去建立这颗线段树?如果可以,它必定是在不断地更新的基础上去动态扩建的。
还是以求区间和为例,我们试想一下,在动态扩建的过程中,每个新节点需要向其他节点传递什么信息?
add
来标识。start
和end
属性。public static class SegmentTreeNode {
public SegmentTreeNode left;
public SegmentTreeNode right;
public int add;
public int val;
public SegmentTreeNode() {}
}
我们定义一个pushDown
函数,它的功能有这么几个:
add
值,以及计算他们的区间和val
。add
值置为0。考虑到add值的传递,以及树中叶子节点的性质,除了当前节点node
,我们还需要两个变量来标识当前节点的左孩子数量和右孩子数量。
代码如下:
private void pushDown(SegmentTreeNode node, int leftNum, int rightNum) {
// 动态开点
if (node.left == null) {
node.left = new SegmentTreeNode();
}
if (node.right == null) {
node.right = new SegmentTreeNode();
}
// 如果当前节点的add值为0,那么我们不需要更新子节点的add值
if (node.val == 0) {
return;
}
// 否则,更新左右子节点的add值
node.left.add += node.add * leftNum;
node.right.add += node.add * rightNum;
// 更新左右子节点的add值
node.left.val += node.add;
node.right.val += node.add;
// 更新当前节点的add值
node.add = 0;
}
我们定义一个函数pushUp
,主要用来计算当前节点的区间值:
private void pushUp(SegmentTreeNode node) {
node.val = node.left.val + node.right.val;
}
/**
* @param node 线段树的根节点
* @param start 线段树的起始位置
* @param end 线段树的结束位置
* @param left 查询区间的左边界
* @param right 查询区间的右边界
* @param addValue 比原本的值多加的值
*/
public void update(SegmentTreeNode node, int start, int end, int left, int right, int addValue) {
// 如果线段树的区间完全在查询区间内,那么直接更新当前节点的add值即可
if (start >= left && end <= right) {
// 该区间内,所有叶子节点都要加上val
node.val += (end - start + 1) * addValue;
// 该区间内,所有非叶子节点都要加上val,传递给后面的新节点
node.add += addValue;
return;
}
// 如果不在查询区间内,那么我们需要递归更新左右子树
int mid = (start + end) >> 1;
// 向下传递标记
pushDown(node, mid - start + 1, end - mid);
if (left <= mid) {
update(node.left, start, mid, left, right, addValue);
}
// [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
if (right > mid) {
update(node.right, mid + 1, end, left, right, addValue);
}
// 计算当前节点的val值
pushUp(node);
}
public int query(SegmentTreeNode node, int start, int end, int left, int right) {
if (left <= start && end <= right) {
return node.val;
}
// 把当前区间 [start, end] 均分得到左右孩子的区间范围
int mid = (start + end) >> 1, ans = 0;
// 下推标记
pushDown(node, mid - start + 1, end - mid);
// [start, mid] 和 [l, r] 可能有交集,遍历左孩子区间
if (left <= mid) {
ans += query(node.left, start, mid, left, right);
}
// [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
if (right > mid) {
ans += query(node.right, mid + 1, end, left, right);
}
return ans;
}
public class SegmentTreeDynamic {
public static class SegmentTreeNode {
public SegmentTreeNode left;
public SegmentTreeNode right;
public int add;
public int val;
public SegmentTreeNode() {
}
}
/**
* @param node 线段树的根节点
* @param leftNum 左节点的叶子数量
* @param rightNum 右节点的叶子数量
*/
private void pushDown(SegmentTreeNode node, int leftNum, int rightNum) {
// 动态开点
if (node.left == null) {
node.left = new SegmentTreeNode();
}
if (node.right == null) {
node.right = new SegmentTreeNode();
}
// 如果当前节点的add值为0,那么我们不需要更新子节点的add值
if (node.val == 0) {
return;
}
// 否则,更新左右子节点的add值
node.left.add += node.add * leftNum;
node.right.add += node.add * rightNum;
// 更新左右子节点的add值
node.left.val += node.add;
node.right.val += node.add;
// 更新当前节点的add值
node.add = 0;
}
private void pushUp(SegmentTreeNode node) {
node.val = node.left.val + node.right.val;
}
/**
* @param node 线段树的根节点
* @param start 线段树的起始位置
* @param end 线段树的结束位置
* @param left 查询区间的左边界
* @param right 查询区间的右边界
* @param addValue 比原本的值多加的值
*/
public void update(SegmentTreeNode node, int start, int end, int left, int right, int addValue) {
// 如果线段树的区间完全在查询区间内,那么直接更新当前节点的add值即可
if (start >= left && end <= right) {
// 该区间内,所有叶子节点都要加上val
node.val += (end - start + 1) * addValue;
// 该区间内,所有非叶子节点都要加上val,传递给后面的新节点
node.add += addValue;
return;
}
// 如果不在查询区间内,那么我们需要递归更新左右子树
int mid = (start + end) >> 1;
// 向下传递标记
pushDown(node, mid - start + 1, end - mid);
if (left <= mid) {
update(node.left, start, mid, left, right, addValue);
}
// [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
if (right > mid) {
update(node.right, mid + 1, end, left, right, addValue);
}
// 计算当前节点的val值
pushUp(node);
}
public int query(SegmentTreeNode node, int start, int end, int left, int right) {
if (left <= start && end <= right) {
return node.val;
}
// 把当前区间 [start, end] 均分得到左右孩子的区间范围
int mid = (start + end) >> 1, ans = 0;
// 下推标记
pushDown(node, mid - start + 1, end - mid);
// [start, mid] 和 [l, r] 可能有交集,遍历左孩子区间
if (left <= mid) {
ans += query(node.left, start, mid, left, right);
}
// [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
if (right > mid) {
ans += query(node.right, mid + 1, end, left, right);
}
return ans;
}
}
Demo
如下:数组:[1, 3, 5, 7, 9, 11]
。
[1,4]
的区间和。[1,4]
的区间和。public static void main(String[] args) {
int[] nums = {1, 3, 5, 7, 9, 11};
SegmentTree segmentTree = new SegmentTree(nums);
int sum = segmentTree.query(1, 4); // 查询区间[1, 4]的和 3+5+7+9=24
System.out.println(sum); // 输出:24
segmentTree.update(2, 19); // 将索引2处的值更新为19
sum = segmentTree.query(1, 4); // 再次查询区间[1, 4]的和
System.out.println(sum); // 输出:38
}
public static void main(String[] args) {
int[] nums = {1, 3, 5, 7, 9, 11};
SegmentTreeDynamic segmentTree = new SegmentTreeDynamic();
SegmentTreeDynamic.SegmentTreeNode root = new SegmentTreeDynamic.SegmentTreeNode();
int n = nums.length - 1;
for (int i = 0; i < nums.length; i++) {
segmentTree.update(root, 0, n, i, i, nums[i]);
}
System.out.println(segmentTree.query(root, 0, n, 1, 4));
segmentTree.update(root, 0, n, 2, 2, 14);// 在原本值的基础上再多14,相当于定长计算中的5 + 14 = 19
System.out.println(segmentTree.query(root, 0, n, 1, 4));
}