线段树是一种用于处理区间查询的数据结构,特别适用于需要高效支持区间修改和区间查询的场景。它将一个数组表示的区间划分为一系列小区间,并将每个小区间的信息存储在树中。本文将介绍线段树的基本原理、实现方法和一些常见应用。
线段树是一种二叉树,每个节点代表一个区间。树的根节点表示整个数组,每个叶子节点表示数组中的一个元素,而每个中间节点表示数组的一个子区间。通过这种方式,线段树将数组分割为许多小区间,每个小区间都对应着树中的一个节点。
构建(Build): 将数组构建成线段树的过程。从树的叶子节点开始,每个节点的值表示对应区间的信息,中间节点的值通常由其子节点计算得到。
查询(Query): 查询某个区间的信息。通过递归地访问线段树的节点,可以高效地获取整个区间的信息。
更新(Update): 更新某个元素或区间的信息。通过递归地更新线段树的节点,可以保持树的信息与数组的一致性。
线段树主要用于解决区间查询和区间更新问题。它是一种二叉树数据结构,用于存储对一个线性结构(如数组或列表)中某个范围的数据进行快速查询和更新。
主要问题包括:
区间查询(Range Query): 求解给定区间内元素的某种属性,例如区间和、区间最小值、区间最大值等。
区间更新(Range Update): 修改给定区间内元素的值,例如增加或减少某个值,或者将区间内的元素全部设为某个值。
线段树通过将整个区间递归地划分为更小的子区间,每个节点存储了对应区间的信息,从而实现了高效的区间查询和更新操作。
#include
#include
struct SegmentTreeNode {
int start;
int end;
int sum; //区间和
SegmentTreeNode* left;
SegmentTreeNode* right;
SegmentTreeNode(int start, int end) : start(start), end(end), sum(0), left(nullptr), right(nullptr) {}
};
class SegmentTree {
private:
SegmentTreeNode* root;
// 递归构建线段树
SegmentTreeNode* buildTree(const std::vector<int>& nums, int start, int end) {
if (start > end) {
return nullptr;
}
SegmentTreeNode* node = new SegmentTreeNode(start, end);
if (start == end) { //叶节点
node->sum = nums[start];
} else {
int mid = (start + end) / 2;
node->left = buildTree(nums, start, mid);
node->right = buildTree(nums, mid + 1, end);
node->sum = node->left->sum + node->right->sum; //左右子树的和
}
return node;
}
// 区间查询
int queryRange(SegmentTreeNode* node, int left, int right) {
if (!node || left > node->end || right < node->start) {
return 0; // 区间不重叠,返回默认值
}
if (left <= node->start && right >= node->end) {
return node->sum; // 完全包含,返回当前节点的值
}
//如果 区间没有被完全包含,查找该区间的左/右子区间
int leftSum = queryRange(node->left, left, right);
int rightSum = queryRange(node->right, left, right);
return leftSum + rightSum;
}
// 单点更新
void updateNode(SegmentTreeNode* node, int index, int value) {
if (!node || index < node->start || index > node->end) {
return; //index超出索引范围
}
if (node->start == node->end && node->start == index) {
node->sum = value;
return;
}
updateNode(node->left, index, value);
updateNode(node->right, index, value);
node->sum = node->left->sum + node->right->sum; //更新父节点
}
public:
SegmentTree(const std::vector<int>& nums) {
int n = nums.size();
root = buildTree(nums, 0, n - 1);
}
// 区间查询接口
int query(int left, int right) {
return queryRange(root, left, right);
}
// 单点更新接口
void update(int index, int value) {
updateNode(root, index, value);
}
};
int main() {
std::vector<int> nums = {1, 3, 5, 7, 9, 11};
SegmentTree segTree(nums);
// 区间查询
std::cout << "Sum in range [1, 4]: " << segTree.query(1, 4) << std::endl;
// 单点更新
segTree.update(2, 6);
std::cout << "Updated sum in range [1, 4]: " << segTree.query(1, 4) << std::endl;
return 0;
}
当你读到此处也许会觉得用线段树来求区间和,并不是那么好用,代码又长、效率也一般。相对于用前缀和来求解区间求和问题可以更快更简洁的求解该问题,实时也确实如此吗?当涉及到需要多次对区间的元素进行单点或者区间范围内修改时,你还会这样认为吗?下面然我们来看一看线段树如何解决区间修改、区间查询问题的吧!
在讲解区间修改之前我们先来思考一下如何进行区间修改,首先想到的就是对一个[l,r]区间逐一的进行单点修改,这确实能够满足区间修改的要求。但是你仔细想想逐一进行修改的时间复杂度,能否让我们满意。答案当然是否定的。
于此,便提出了lazytage标记。
区间修改的时候,我们按照如下原则:
1、如果当前区间被完全覆盖在目标区间里,讲这个区间的 node->sum = node->sum+k*(node->end - node->start+1)
2、如果没有完全覆盖,则先下传懒标记
3、如果这个区间的左儿子和目标区间有交集,那么搜索左儿子
4、如果这个区间的右儿子和目标区间有交集,那么搜索右儿子
然后查询的时候,将这个懒标记下传就好了,下面图解一下:
如图,区间 1 ∼ 4 分别是 1 、 2 、 3 、 4 我们要把 1 ∼ 3 区间+1 。因为 1 ∼ 2 区间被完全覆盖,所以将其 + 2 ,并将紫色的 lazytage+1,3 区间同理。
在完成上述步骤之后任需要执行以下操作:node->sum = node->left->sum + node->right->sum;
void updateRange(SegmentTreeNode* node, int left, int right, int delta) {
if (!node || left > node->end || right < node->start) {
return;
}
if (left <= node->start && right >= node->end) {
node->sum += (node->end - node->start + 1) * delta; //更新sum
node->lz += delta; // lazy累积
return;
}
// lazy下移
pushdown(node);
updateRange(node->left.get(), left, right, delta);
updateRange(node->right.get(), left, right, delta);
node->sum = node->left->sum + node->right->sum;
}
需要注意的是,在区间更新和区间查询之前需要先处理lazy,如果lazy不为0,需要将lazy不断下移到子节点中,然后将当前节点的lazy置为0。
其中的pushdown函数,就是把自己的lazy归零,并给自己的儿子加上,并让自己儿子的sum加上k*(r-l+1).
void pushdown(SegmentTreeNode* node) {
//如果lazy == 0 则返回
if (!node->lz) {
return;
}
node->left->lz += node->lz;
node->right->lz += node->lz;
node->left->sum += (node->left->end - node->left->start + 1) * node->lz;
node->right->sum += (node->right->end - node->left->start + 1) * node->lz;
node->lz = 0; //将当前节点的lazy置为0
}
#include
#include
#include
using namespace std;
struct SegmentTreeNode {
int start;
int end;
int sum;
int lz;
std::unique_ptr<SegmentTreeNode> left;
std::unique_ptr<SegmentTreeNode> right;
SegmentTreeNode(int start, int end) : start(start), end(end), sum(0), lz(0), left(nullptr), right(nullptr) {}
};
class SegmentTree {
private:
SegmentTreeNode* root;
// 递归构建线段树
SegmentTreeNode* buildTree(const std::vector<int>& nums, int start, int end) {
if (start > end) {
return nullptr;
}
SegmentTreeNode* node = new SegmentTreeNode(start, end);
if (start == end) {
node->sum = nums[start];
} else {
int mid = (start + end) / 2;
node->left.reset(buildTree(nums, start, mid));
node->right.reset(buildTree(nums, mid + 1, end));
node->sum = node->left->sum + node->right->sum;
}
return node;
}
// pushdown函数,处理lazy
void pushdown(SegmentTreeNode* node) {
//如果lazy == 0 则返回
if (!node->lz) {
return;
}
node->left->lz += node->lz;
node->right->lz += node->lz;
node->left->sum += (node->left->end - node->left->start + 1) * node->lz;
node->right->sum += (node->right->end - node->left->start + 1) * node->lz;
node->lz = 0; //将当前节点的lazy置为0
}
// 区间查询
int queryRange(SegmentTreeNode* node, int left, int right) {
if (!node || left > node->end || right < node->start) {
return 0;
}
if (left <= node->start && right >= node->end) {
return node->sum;
}
// lazy下移
pushdown(node);
int leftSum = queryRange(node->left.get(), left, right);
int rightSum = queryRange(node->right.get(), left, right);
return leftSum + rightSum;
}
// 区间更新
void updateRange(SegmentTreeNode* node, int left, int right, int delta) {
if (!node || left > node->end || right < node->start) {
return;
}
if (left <= node->start && right >= node->end) {
node->sum += (node->end - node->start + 1) * delta; //更新sum
node->lz += delta; // lazy累积
return;
}
// lazy下移
pushdown(node);
updateRange(node->left.get(), left, right, delta);
updateRange(node->right.get(), left, right, delta);
node->sum = node->left->sum + node->right->sum;
}
public:
SegmentTree(const std::vector<int>& nums) {
int n = nums.size();
root = buildTree(nums, 0, n - 1);
}
// 区间查询接口
int query(int left, int right) {
return queryRange(root, left, right);
}
// 区间更新接口
void update(int left, int right, int delta) {
updateRange(root, left, right, delta);
}
};
int main() {
std::vector<int> nums = {1, 3, 5, 7, 9, 11};
SegmentTree segTree(nums);
// 区间查询
std::cout << "Sum in range [1, 4]: " << segTree.query(1, 4) << std::endl;
// 区间增加
segTree.update(1, 2, 2);
// 更新后的区间查询
std::cout << "Updated sum in range [1, 4]: " << segTree.query(1, 4) << std::endl;
return 0;
}