前段时间写过一篇关于树状数组的博客树状数组,今天我们要介绍的是线段树,线段树比树状数组中的应用场景更加的广泛。这些问题也是在leetcode 11月的每日一题频繁遇到的问题,实际上线段树就和红黑树 、堆一样是一类模板,但是标准库里面并没有(所以题目的代码量会比较大)。如果我们深刻了解了其中的原理,刷到的时候默写出来问题也不是很大。
如果没有看过线段树博客的先去看那个,再来看这个问题
在树状数组中我们的问题是:
给你一个数组 nums ,请你完成两类查询。
其中一类查询要求 更新 数组 nums 下标对应的值
另一类查询要求返回数组 nums 中索引 left 和索引 right 之间( 包含 )的nums元素的 和 ,其中 left <= right
实现 NumArray 类:
NumArray(int[] nums)
用整数数组 nums 初始化对象void update(int index, int val)
将 nums[index] 的值 更新 为 valint sumRange(int left, int right)
返回数组 nums 中索引 left 和索引 right 之间( 包含 )的nums元素的 和 (即,nums[left] + nums[left + 1], …, nums[right])
在这个问题中我们对区间的修改始终是单点修改,如果我们想修改一个区间的值(指给这个区间的所有值都加、减一个数),这时候树状数组只能遍历这个区间,然后对区间每一个数做单点修改,这样修改的时间复杂度就是M*logN
,M为区间长度,N为整个区间的大小。
但是如果用线段树来解决,每次区间修改的时间复杂度可以降到logN
线段树不像前面介绍的树状数组一样,树状数组逻辑结构是树,但是物理结构是一个数组。而线段树是一个真正的树型结构。
假设我们有一个长度为10 的数组。如果我们要构建一个线段树一定是下面这个结构:
我们观察可以发现:
[left,right]
,那么子节点的区间分别为[left,mid]
、[mid+1,right]
其中mid=(left+right)/2
根据上图我们很轻松的就可以定义出树节点的数据结构。
struct STNode {
STNode* left; // 左节点
STNode* right; // 右节点
int val; // 维护的区间的值(可能是区间和、区间的最值)
int add; //这个是lazy标记,我们后面会介绍,这里我们可以先忽略
};
实际上这个建立树的过程,后面会讲到线段树的动态开点,否则leetcode的内存会爆掉。
假设我们给定一个数组,要求你根据这个数组构建一个线段树,这个线段树的每一个区间维护的都是这个区间的区间和。
class SegmentTree {
public:
void PushUp(STNode *cur) {
cur->val = cur->left->val + cur->right->val;
}
void BuildTree(const vector<int>& v) {
N = v.size(); // 用于记录线段树查询区间的大小
function<void(int,int,STNode*&)> dfs = [&](int l,int r,STNode*& cur) {
if (l == r) {
cur=new STNode{ nullptr,nullptr,v[l],0};
return;
}
int mid = (r - l) / 2 + l;
cur=new STNode{ nullptr,nullptr,0,0 };
dfs(l, mid, cur->left);
dfs(mid + 1, r, cur->right);
PushUp(cur);
};
dfs(0, v.size() - 1, root);
}
private:
STNode* root = nullptr; // 根节点
int N; // 用于记录线段树查询区间的大小
};
注意这个递归建立线段树,一定要选对遍历的方式——后续遍历,因为只要一个节点有子节点,那么他的值一定是在两个子节点已经遍历完了之后才能确定,这符合后续遍历的定义:先遍历完左子树和右子树,最后两个子树的信息汇总到根节点。
这样一颗完整的线段树便可以确定。注意这里的PushUp
函数就是遍历完左右子树后的更新根节点,由于后面的查询和区间修改都需要这个操作所以这里需要重点理解一下!
线段树查询的思路是:
假设被查询的区间是[left,right] ,当前节点的区间为 [start , end]
left <=start && right>=end
则返回区间的值mid = (start +end)/2
求出其两个子节点的区间,然后将判断查询区间和哪个子节点相交,继续向子节点向下递归,直到找到第一种情况为止对于区间查询,比方说我们需要查询[2,6]
这个区间的区间和,红色的是查询的路径,蓝色是最终组成这个[2,6]
区间的节点,也是查询的结束的地方。
很简单,我们顺着上面的逻辑很快就能写出查询某个区间的代码,这里就不做过多的解释了
class SegmentTree {
public:
void PushUp(STNode *cur) {
cur->val = cur->left->val + cur->right->val;
}
// 区间查询
int Query(int left,int right) {
// left right 为要查询的区间 start end 为当前节点所维护的区间
function<int(int, int, int, int, STNode*)> dfs = [&](int left, int right,int start ,int end, STNode* cur) {
if (left <= start && right >= end) {
return cur->val;
}
int mid = (end - start) / 2 + start;
int ansl = 0, ansr = 0;
if (left > mid) {
ansl = dfs(left, right, mid + 1, end, cur->right);
}
else if (right <=mid) {
ansr = dfs(left, right, start, mid, cur->left);
}
else {
ansl = dfs(left, right, mid + 1, end, cur->right);
ansr = dfs(left, right, start, mid, cur->left);
}
//PushUp(cur); // 向上更新
return ansl + ansr;
};
return dfs(left, right, 0, N, root);
}
private:
STNode* root = nullptr; // 根节点
int N; // 用于记录线段树查询区间的大小
};
这里我们讨论的区间修改只的是让区间[left,right]
都加上或减去一个val
到这里我们来分析一下如何进行区间修改,这里我们和区间查询的思路一致:
假设被查询的区间是[left,right] ,当前节点的区间为 [start , end]
left <=start && right>=end
则更新区间的值,但是这里注意,我们更新了大区间的值,下面他的所有小区间都需要更新 ,下面的代码PushDown
函数就是处理这个问题的。mid = (start +end)/2
求出其两个子节点的区间,然后将判断查询区间和哪个子节点相交,继续向子节点向下递归,直到找到第一种情况为止class SegmentTree {
public:
void PushUp(STNode *cur) {
cur->val = cur->left->val + cur->right->val;
}
void PushDown(int l,int r,STNode* cur,int val) {
if (cur == nullptr) {
return;
}
cur->val += (r - l+1) * val;
int mid = (r - l) / 2 + l;
PushDown(l, mid, cur->left, val);
PushDown(mid + 1, r, cur->right, val);
}
void Update(int left, int right,int val) {
// left right 为要查询的区间 start end 为当前节点所维护的区间 val为区间修改的值
function<void(int, int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, int val, STNode* cur) {
if (left <= start && right >= end) {
PushDown(start, end, cur, val); // 将cur区间所有子节点全部更新了
return;
}
int mid = (end - start) / 2 + start;
int ansl = 0, ansr = 0;
if (left > mid) {
dfs(left, right, mid + 1, end, val, cur->right);
}
else if (right <= mid) {
dfs(left, right, start, mid, val, cur->left);
}
else {
dfs(left, right, mid + 1, val, end, cur->right);
dfs(left, right, start, mid, val, cur->left);
}
PushUp(cur); // 向上更新 值,因为修改了子区间 其所有父区间都会被修改!
};
dfs(left, right, 0, N,val, root);
}
private:
STNode* root = nullptr; // 根节点
int N; // 用于记录线段树查询区间的大小
};
代码写到这里,线段树的功能已经被我们全部实现了,但是我们回国头来看一下线段树的实现思路,看看是否还有优化的地方。
我们主要把目光放在 线段树的修改上面,我们上面线段树的修改的思路实际上和树状数组的遍历区间的单点修改是没有什么区别的,并没有什么优势,时间复杂度都是:时间复杂度就是M*logN
,M为区间长度,N为整个区间的大小。
所以我们接下来要介绍一下优化方案——Lazy标记
我们的区间修改 的优化方案 主要集中在PushDown
这个函数上,我们当前的区间修改,对于节点的区间被查询区间包含的情况,是修改当前节点,并将修改传递给当前节点的每一个子节点。而我们的Lazy
标记实际上就是优化这一步。
实际上我们发现我们修改完当前的节点时候(对于节点的区间被查询区间包含的情况),实际上并不需要将子节点修改,或者换一种说法:不需要立刻修改所有的子节点,我们可以设置一个Lazy标记。
这个Lazy标记的定义是:
记录所有子节点的区间应当被修改,但是未被修改的值。(这些值先由父节点记录,等到后面如果要访问到下面的节点的时候,顺路把标记带下去)
这些懒标记,我们可以在后面查询的时候,比方说访问到懒标记所在节点的子节点的时候,在去把Lazy标记下放,也就是把Lazy标记传递到子节点。
假设我有一个全0的长度位10 的数组,需要更新区间[0,6]
这时候按照 Lazy标记的 定义,我们遍历到区间包含节点区间的时候,就不必向下更新了,而是直接修改节点的Lazy标记,
这样
我们回过头来看,树节点的设计中我们一直都留着一个add
,这个就是每个树节点的Lazy标记
struct STNode {
STNode* left;
STNode* right;
int val;
int add; // Lazy标记
};
我们更新[0,6]
这个区间,递归结束的节点始终满足节点区间被查询区间所包含,注意更新的值为:区间长度*val
但是别忘了最后的回溯更新,所以这次区间更新的最终结果是,黄色的线代表回溯更新的路线
假设我们这时候需要查询或者更新[0,2]
我们这里以更新区间[0,2]
,让区间里面的每一个元素加1为例:
当我们遍历到区间[0,2]
的时候,这时候就需要下放Lazy标记到他的子节点,注意这里下放的Lazy标记代表节点区间[0,2]
和节点区间[3,5]
理论上应该区间每一个元素+1,所以每个节点的val需要+3,然后两个子节点lazy标记需要加上父节点的Lazy标记(注意不是直接赋值),因为这两个子节点的子节点并没有更新Lazy标记。最后记得把父节点区间[0,5]
的lazy标记置0!
然后由于我们还需要更新[0,2]
节点的值,所以我们直接更新节点区间[0,2]
lazy标记,让Lazy标记+=1(因为要更新区间[0,2]
让区间每一个元素加1)
从上面的流程我们可以看出,Lazy标记实际上是一种向下传递,所以我们只需要修改PushDown
函数的逻辑既可,那我们就要思考一个问题:何时需要向下传递?
答案很简单:当查询区间不包含当前节点区间的时候,这时候查询的区间一定在当前区间的子节点中,所以需要向下递归,所以在向下递归之前我们需要下传Lazy标记!
这句话需要好好理解一下。
void PushDown(int l, int r, STNode* cur) {
if (cur->add == 0) // 如果当前节点没有Lazy标记,则不需要下传直接返回
return;
int mid = (r - l) / 2 + l;
// 将父节点的 Lazy 标记下传到
cur->left->add += cur->add;
cur->right->add += cur->add;
cur->left->val += (mid - l + 1) * cur->add;
cur->right->val += (r - mid) * cur->add;
cur->add = 0;
}
然后相应的Query
和Update
函数都要做出相应的修改
void Update(int left, int right, int val) {
// left right 为要查询的区间 start end 为当前节点所维护的区间 val为区间修改的值
function<void(int, int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, int val, STNode* cur) {
if (left <= start && right >= end) {
cur->val += (end - start + 1) * val;
cur->add += val;
return;
}
// 走到这里代表要查询的区间[left,right]一定在cur节点的子节点中! 一定要理解!!!!!
PushDown(start, end, cur);
int mid = (end - start) / 2 + start;
int ansl = 0, ansr = 0;
if (left > mid) {
dfs(left, right, mid + 1, end, val, cur->right);
}
else if (right <= mid) {
dfs(left, right, start, mid, val, cur->left);
}
else {
dfs(left, right, mid + 1, end, val, cur->right);
dfs(left, right, start, mid, val, cur->left);
}
PushUp(cur); // 向上更新父节点
};
dfs(left, right, 0, N, val, root);
}
int Query(int left, int right) {
// left right 为要查询的区间 start end 为当前节点所维护的区间
function<int(int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, STNode* cur) {
if (left <= start && right >= end) {
return cur->val;
}
// 走到这里代表要查询的区间[left,right]一定在cur节点的子节点中! 一定要理解!!!!!
PushDown(start, end, cur);
int mid = (end - start) / 2 + start;
int ansl = 0, ansr = 0;
if (left > mid) {
ansl = dfs(left, right, mid + 1, end, cur->right);
}
else if (right <= mid) {
ansr = dfs(left, right, start, mid, cur->left);
}
else {
ansl = dfs(left, right, mid + 1, end, cur->right);
ansr = dfs(left, right, start, mid, cur->left);
}
PushUp(cur);
return ansl + ansr;
};
return dfs(left, right, 0, N, root);
}
到这里实际上线段树就已经结束了,但是做题的时候如果这样搞理论上应该是无法通过的,因为题目的区间都是很大的,一般题目给的区间都是[1,1000000000]
,如果我们按照上面的构建线段树的方法,内存应该无法通过。与是这里我们又有了一种新的方法——动态开点
我们设想如果我们要查询的区间落在[1,1000000000]
上,如果我们上来就开辟1000000000个节点构成线段树,设想下面这种情况,我们的区间查询和更新都集中在[0,5000000000]
上,而区间[500000001,1000000000]
只被查询过一次,那么区间[500000001,1000000000]
下面500000000个节点实际上是不用开辟的!
空间上是极大的浪费,我们这里的处理的方法和Lazy标记一样:就是当我们用到该节点的时候,我们再去创建该节点。
而我们何时知道需要访问到子节点?
也就是查询区间不包含当前节点区间的时候,意味着目标节点肯定在子节点中,这时候我们再去开辟子节点
我们把这部分逻辑放在PushDown
函数里面去实现。
void PushDown(int l, int r, STNode* cur) {
// 创建子节点 如果没有的话
if (cur->left == nullptr) cur->left = new STNode{ nullptr,nullptr,0,0 };
if (cur->right == nullptr) cur->right = new STNode{ nullptr,nullptr,0,0 };
if (cur->add == 0)
return;
int mid = (r - l) / 2 + l;
cur->left->add += cur->add;
cur->right->add += cur->add;
cur->left->val += (mid - l + 1) * cur->add;
cur->right->val += (r - mid) * cur->add;
cur->add = 0;
}
到这里线段树的模板就已经写完了
struct STNode {
STNode* left;
STNode* right;
int val = 0;
int add;
};
class SegmentTree {
public:
void PushUp(STNode* cur) {
cur->val = cur->left->val + cur->right->val;
}
SegmentTree() {
N = 1000000000;
root = new STNode{ nullptr,nullptr,0,0 };
}
void PushDown(int l, int r, STNode* cur) {
if (cur->left == nullptr) cur->left = new STNode{ nullptr,nullptr,0,0 };
if (cur->right == nullptr) cur->right = new STNode{ nullptr,nullptr,0,0 };
if (cur->add == 0)
return;
int mid = (r - l) / 2 + l;
cur->left->add += cur->add;
cur->right->add += cur->add;
cur->left->val += (mid - l + 1) * cur->add;
cur->right->val += (r - mid) * cur->add;
cur->add = 0;
}
void Update(int left, int right, int val) {
// left right 为要查询的区间 start end 为当前节点所维护的区间 val为区间修改的值
function<void(int, int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, int val, STNode* cur) {
if (left <= start && right >= end) {
cur->val += (end - start + 1) * val;
cur->add += val;
return;
}
PushDown(start, end, cur);
int mid = (end - start) / 2 + start;
int ansl = 0, ansr = 0;
if (left > mid) {
dfs(left, right, mid + 1, end, val, cur->right);
}
else if (right <= mid) {
dfs(left, right, start, mid, val, cur->left);
}
else {
dfs(left, right, mid + 1, end, val, cur->right);
dfs(left, right, start, mid, val, cur->left);
}
PushUp(cur);
};
dfs(left, right, 0, N, val, root);
}
int Query(int left, int right) {
// left right 为要查询的区间 start end 为当前节点所维护的区间
function<int(int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, STNode* cur) {
if (left <= start && right >= end) {
return cur->val;
}
PushDown(start, end, cur);
int mid = (end - start) / 2 + start;
int ansl = 0, ansr = 0;
if (left > mid) {
ansl = dfs(left, right, mid + 1, end, cur->right);
}
else if (right <= mid) {
ansr = dfs(left, right, start, mid, cur->left);
}
else {
ansl = dfs(left, right, mid + 1, end, cur->right);
ansr = dfs(left, right, start, mid, cur->left);
}
PushUp(cur);
return ansl + ansr;
};
return dfs(left, right, 0, N, root);
}
private:
STNode* root = nullptr; // 根节点
int N; // 用于记录线段树查询区间的大小
};
但是在做题的时候还是会对模板进行修改,修改的点主要在区间val
的意义上,我们这里模板里面val
代表的是区间和,但后面题目里面可能是 区间最大值 、区间最小值…,但是万变不离其宗,线段树的结构是不会变的。
题目的详情 点击上面链接☝
这道题看完题目描述很明显是一道线段树的题目,题目要求我们每次返回被标记的节点的数量(一个节点不能被重复标记),我们可以统计区间元素的个数来解决这个问题, 但是我们很快发现了一个问题: 如果我们标记了一个区间的节点,但是下次需要标记的区间与该区间重合,如果只是简单的统计区间的数量 那么就会造成重叠区间元素被重复统计!
所以我们这题的解决思路是:
val
设置为区间长度既可(代表这个区间被标记了),并设置Lazy标记[0,5]
但是 前面标记过[0,2]
区间,这时节点的val=2
,但是我们当前需要标记[0,5]
,这时我们需要 将val
设置为当前区间的长度 即另val=6
,并设置Lazy标记PushDown
函数里面处理Lazy标记的时候,只要Lazy标记不为0,就把当前区间的val
置为节点区间的长度。JAVA版本
class STNode{
STNode(STNode l,STNode r,int val,int add){
this.left=l;
this.right=r;
this.val=val;
this.add=add;
}
STNode left;
STNode right;
int val=0;
int add=0;
}
class CountIntervals {
STNode root;
int N=1000000000;
public CountIntervals() {
root=new STNode(null,null,0,0);
}
private void PushDown(STNode cur,int begin,int end){
if(cur.left==null)
cur.left=new STNode(null,null,0,0);
if(cur.right==null)
cur.right=new STNode(null,null,0,0);
if(cur.add==0)
return;
int mid=(end-begin)/2+begin;
cur.left.val=(mid-begin+1); // Lazy标记向下传递到时候 ,只需要把val置为区间长度既可
cur.right.val=(end-mid);
cur.left.add++;
cur.right.add++;
cur.add=0;
}
private void PushUp(STNode cur){
cur.val=cur.left.val+cur.right.val;
}
void Update(int left,int right,int begin,int end,STNode cur){
if(left<=begin && right>=end){
if(cur.val!=(end-begin+1)){
cur.val=(end-begin+1);
cur.add++;
}
return ;
}
int mid=(end-begin)/2+begin;
PushDown(cur,begin,end);
if(left>mid){
Update(left,right,mid+1,end,cur.right);
}else if(right<=mid){
Update(left,right,begin,mid,cur.left);
}else{
Update(left,right,mid+1,end,cur.right);
Update(left,right,begin,mid,cur.left);
}
PushUp(cur);
}
public void add(int left, int right) {
Update(left,right,0,N,root);
}
public int count() {
PushDown(root,0,N);
PushUp(root);
return root.val;
}
}
C++版本
struct STNode {
STNode* left;
STNode* right;
int val;
int add;
};
class CountIntervals {
public:
STNode* root = nullptr;
void PushUp(STNode* cur) {
cur->val = cur->left->val + cur->right->val;
}
void PushDown(STNode* cur, int begin, int end) {
if (cur->left == nullptr)
cur->left = new STNode{ nullptr,nullptr,0,0 };
if (cur->right == nullptr)
cur->right = new STNode{ nullptr,nullptr,0,0 };
if (cur->add == 0)
return;
int mid = (end - begin) / 2 + begin;
cur->left->val = (mid-begin+1);
cur->right->val = (end - mid);
cur->left->add += cur->add;
cur->right->add += cur->add;
cur->add = 0;
}
void Update(int left, int right, int begin, int end, STNode* cur) {
if (left <= begin && right >= end) {
if (cur->val != end - begin + 1) {
cur->val = (end - begin + 1);
cur->add += 1;
}
return;
}
PushDown(cur,begin,end);
int mid = (end - begin) / 2 + begin;
if (left > mid)
Update(left, right, mid + 1, end, cur->right);
else if (right <= mid)
Update(left, right, begin, mid, cur->left);
else {
Update(left, right, mid + 1, end, cur->right);
Update(left, right, begin, mid, cur->left);
}
PushUp(cur);
}
CountIntervals() {
root = new STNode{ nullptr,nullptr,0,0 };
}
void add(int left, int right) {
Update(left, right, 0, 1000000000, root);
//cout<
}
int count() {
PushDown(root,0,1000000000);
root->val = root->left->val + root->right->val;
return root->val;
}
};
/**
* Your CountIntervals object will be instantiated and called as such:
* CountIntervals* obj = new CountIntervals();
* obj->add(left,right);
* int param_2 = obj->count();
*/