详解线段树

前段时间写过一篇关于树状数组的博客树状数组,今天我们要介绍的是线段树,线段树比树状数组中的应用场景更加的广泛。这些问题也是在leetcode 11月的每日一题频繁遇到的问题,实际上线段树就和红黑树 、堆一样是一类模板,但是标准库里面并没有(所以题目的代码量会比较大)。如果我们深刻了解了其中的原理,刷到的时候默写出来问题也不是很大。

文章目录

  • 问题引入
  • 线段树的引入
  • 树节点数据结构
  • 建立一个线段树
  • 线段树的查询
  • 线段树的区间修改
  • 小结
  • Lazy标记
    • 区间修改
      • 向下传递
    • 代码
  • 动态开点
  • 完整模板
  • 总结
  • 例题:
    • [我的日程安排表 1](https://leetcode.cn/problems/my-calendar-i/description/)
    • [我的日程安排表 II](https://leetcode.cn/problems/my-calendar-ii/)
    • [我的日程安排表 III](https://leetcode.cn/problems/my-calendar-iii/description/)
    • [715. Range 模块](https://leetcode.cn/problems/range-module/)
    • 统计区间中的整数数目


如果没有看过线段树博客的先去看那个,再来看这个问题

问题引入

在树状数组中我们的问题是:

给你一个数组 nums ,请你完成两类查询。
其中一类查询要求 更新 数组 nums 下标对应的值
另一类查询要求返回数组 nums 中索引 left 和索引 right 之间( 包含 )的nums元素的 和 ,其中 left <= right
实现 NumArray 类:

  • NumArray(int[] nums) 用整数数组 nums 初始化对象
  • void update(int index, int val) 将 nums[index] 的值 更新 为 val
  • int sumRange(int left, int right) 返回数组 nums 中索引 left 和索引 right 之间( 包含 )的nums元素的 和 (即,nums[left] + nums[left + 1], …, nums[right])

在这个问题中我们对区间的修改始终是单点修改,如果我们想修改一个区间的值(指给这个区间的所有值都加、减一个数),这时候树状数组只能遍历这个区间,然后对区间每一个数做单点修改,这样修改的时间复杂度就是M*logN,M为区间长度,N为整个区间的大小。

但是如果用线段树来解决,每次区间修改的时间复杂度可以降到logN

线段树的引入

线段树不像前面介绍的树状数组一样,树状数组逻辑结构是树,但是物理结构是一个数组。而线段树是一个真正的树型结构。
假设我们有一个长度为10 的数组。如果我们要构建一个线段树一定是下面这个结构:
详解线段树_第1张图片
我们观察可以发现:

  • 每个节点都是维护的一个区间的统计值(可以是区间和 区间的最大值 、最小值)
  • 所有的叶子节点的区间长度均为1,也就是数组的一个单点元素的值
  • 如果一个节点有子节点,且区间为[left,right] ,那么子节点的区间分别为[left,mid][mid+1,right] 其中mid=(left+right)/2

树节点数据结构

根据上图我们很轻松的就可以定义出树节点的数据结构。

struct STNode {
    STNode* left;    // 左节点
    STNode* right;   // 右节点
    int val;    // 维护的区间的值(可能是区间和、区间的最值)
    int add;  //这个是lazy标记,我们后面会介绍,这里我们可以先忽略
};

建立一个线段树

实际上这个建立树的过程,后面会讲到线段树的动态开点,否则leetcode的内存会爆掉。
假设我们给定一个数组,要求你根据这个数组构建一个线段树,这个线段树的每一个区间维护的都是这个区间的区间和。


class SegmentTree {
public:
    void PushUp(STNode *cur) {
        cur->val = cur->left->val + cur->right->val;
    }
    void BuildTree(const vector<int>& v) {
        N = v.size();  // 用于记录线段树查询区间的大小
        function<void(int,int,STNode*&)> dfs = [&](int l,int r,STNode*& cur) {
            if (l == r) {
                cur=new STNode{ nullptr,nullptr,v[l],0};
                return;
            }
            int mid = (r - l) / 2 + l;
            cur=new STNode{ nullptr,nullptr,0,0 };
            dfs(l, mid, cur->left);
            dfs(mid + 1, r, cur->right);

            PushUp(cur);
        };

        dfs(0, v.size() - 1, root);
    }
private:
    STNode* root = nullptr;   // 根节点
    int N;                    // 用于记录线段树查询区间的大小
};

注意这个递归建立线段树,一定要选对遍历的方式——后续遍历,因为只要一个节点有子节点,那么他的值一定是在两个子节点已经遍历完了之后才能确定,这符合后续遍历的定义:先遍历完左子树和右子树,最后两个子树的信息汇总到根节点。
这样一颗完整的线段树便可以确定。注意这里的PushUp函数就是遍历完左右子树后的更新根节点,由于后面的查询和区间修改都需要这个操作所以这里需要重点理解一下!

线段树的查询

线段树查询的思路是:

假设被查询的区间是[left,right] ,当前节点的区间为 [start , end]

  • 如果查询的区间正好包括了节点所代表的区间,即left <=start && right>=end 则返回区间的值
  • 如果查询的区间不包含节点的区间,那么我们可以通过 mid = (start +end)/2 求出其两个子节点的区间,然后将判断查询区间和哪个子节点相交,继续向子节点向下递归,直到找到第一种情况为止

对于区间查询,比方说我们需要查询[2,6]这个区间的区间和,红色的是查询的路径,蓝色是最终组成这个[2,6]区间的节点,也是查询的结束的地方。
详解线段树_第2张图片

很简单,我们顺着上面的逻辑很快就能写出查询某个区间的代码,这里就不做过多的解释了

class SegmentTree {
public:
    void PushUp(STNode *cur) {
        cur->val = cur->left->val + cur->right->val;
    }
    // 区间查询
    int Query(int left,int right) {
        // left right 为要查询的区间   start end 为当前节点所维护的区间
        function<int(int, int, int, int, STNode*)> dfs = [&](int left, int right,int start ,int end, STNode* cur) {
            if (left <= start && right >= end) {
                return cur->val;
            }
            int mid = (end - start) / 2 + start;
            int ansl = 0, ansr = 0;
            if (left > mid) {
                ansl = dfs(left, right, mid + 1, end, cur->right);
            }
            else if (right <=mid) {
                ansr = dfs(left, right, start, mid, cur->left);
            }
            else {
                ansl = dfs(left, right, mid + 1, end, cur->right);
                ansr = dfs(left, right, start, mid, cur->left);
            }
            //PushUp(cur);   // 向上更新
            return ansl + ansr;
        };
        return dfs(left, right, 0, N, root);
    }
private:
    STNode* root = nullptr;   // 根节点
    int N;                    // 用于记录线段树查询区间的大小
};

线段树的区间修改

这里我们讨论的区间修改只的是让区间[left,right] 都加上或减去一个val
到这里我们来分析一下如何进行区间修改,这里我们和区间查询的思路一致:

假设被查询的区间是[left,right] ,当前节点的区间为 [start , end]

  • 如果查询的区间正好包括了节点所代表的区间,即left <=start && right>=end 则更新区间的值,但是这里注意,我们更新了大区间的值,下面他的所有小区间都需要更新 ,下面的代码PushDown函数就是处理这个问题的。
  • 如果查询的区间不包含节点的区间,那么我们可以通过 mid = (start +end)/2 求出其两个子节点的区间,然后将判断查询区间和哪个子节点相交,继续向子节点向下递归,直到找到第一种情况为止
class SegmentTree {
public:
    void PushUp(STNode *cur) {
        cur->val = cur->left->val + cur->right->val;
    }
    void PushDown(int l,int r,STNode* cur,int val) {
        if (cur == nullptr) {
            return;
        }
        cur->val += (r - l+1) * val;
        int mid = (r - l) / 2 + l;
        PushDown(l, mid, cur->left, val);
        PushDown(mid + 1, r, cur->right, val);

    }
    void Update(int left, int right,int val) {
        // left right 为要查询的区间   start end 为当前节点所维护的区间  val为区间修改的值
        function<void(int, int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, int val, STNode* cur) {
            if (left <= start && right >= end) {
                PushDown(start, end, cur, val);  // 将cur区间所有子节点全部更新了
                return;
            }
            int mid = (end - start) / 2 + start;
            int ansl = 0, ansr = 0;
            if (left > mid) {
                dfs(left, right, mid + 1, end, val, cur->right);
            }
            else if (right <= mid) {
                dfs(left, right, start, mid, val, cur->left);
            }
            else {
                dfs(left, right, mid + 1, val, end, cur->right);
                dfs(left, right, start, mid, val, cur->left);
            }
            PushUp(cur);   // 向上更新 值,因为修改了子区间 其所有父区间都会被修改!
         };
        dfs(left, right, 0, N,val, root);
    }
private:
    STNode* root = nullptr;   // 根节点
    int N;                    // 用于记录线段树查询区间的大小
};

小结

代码写到这里,线段树的功能已经被我们全部实现了,但是我们回国头来看一下线段树的实现思路,看看是否还有优化的地方。
我们主要把目光放在 线段树的修改上面,我们上面线段树的修改的思路实际上和树状数组的遍历区间的单点修改是没有什么区别的,并没有什么优势,时间复杂度都是:时间复杂度就是M*logN,M为区间长度,N为整个区间的大小。
所以我们接下来要介绍一下优化方案——Lazy标记

Lazy标记

区间修改

我们的区间修改 的优化方案 主要集中在PushDown这个函数上,我们当前的区间修改,对于节点的区间被查询区间包含的情况,是修改当前节点,并将修改传递给当前节点的每一个子节点。而我们的Lazy标记实际上就是优化这一步。

实际上我们发现我们修改完当前的节点时候(对于节点的区间被查询区间包含的情况),实际上并不需要将子节点修改,或者换一种说法:不需要立刻修改所有的子节点,我们可以设置一个Lazy标记。
这个Lazy标记的定义是:
记录所有子节点的区间应当被修改,但是未被修改的值。(这些值先由父节点记录,等到后面如果要访问到下面的节点的时候,顺路把标记带下去)
这些懒标记,我们可以在后面查询的时候,比方说访问到懒标记所在节点的子节点的时候,在去把Lazy标记下放,也就是把Lazy标记传递到子节点。

假设我有一个全0的长度位10 的数组,需要更新区间[0,6]
详解线段树_第3张图片

这时候按照 Lazy标记的 定义,我们遍历到区间包含节点区间的时候,就不必向下更新了,而是直接修改节点的Lazy标记,
这样
我们回过头来看,树节点的设计中我们一直都留着一个add,这个就是每个树节点的Lazy标记

struct STNode {
    STNode* left;
    STNode* right;
    int val;
    int add; // Lazy标记
};

我们更新[0,6]这个区间,递归结束的节点始终满足节点区间被查询区间所包含,注意更新的值为:区间长度*val

但是别忘了最后的回溯更新,所以这次区间更新的最终结果是,黄色的线代表回溯更新的路线

详解线段树_第4张图片

向下传递

假设我们这时候需要查询或者更新[0,2]
我们这里以更新区间[0,2],让区间里面的每一个元素加1为例:

当我们遍历到区间[0,2]的时候,这时候就需要下放Lazy标记到他的子节点,注意这里下放的Lazy标记代表节点区间[0,2]和节点区间[3,5]理论上应该区间每一个元素+1,所以每个节点的val需要+3,然后两个子节点lazy标记需要加上父节点的Lazy标记(注意不是直接赋值),因为这两个子节点的子节点并没有更新Lazy标记。最后记得把父节点区间[0,5]的lazy标记置0!
详解线段树_第5张图片

然后由于我们还需要更新[0,2]节点的值,所以我们直接更新节点区间[0,2]lazy标记,让Lazy标记+=1(因为要更新区间[0,2]让区间每一个元素加1)
详解线段树_第6张图片

最后别忘了回溯的时候更新父节点的val!
详解线段树_第7张图片

代码

从上面的流程我们可以看出,Lazy标记实际上是一种向下传递,所以我们只需要修改PushDown函数的逻辑既可,那我们就要思考一个问题:何时需要向下传递?
答案很简单:当查询区间不包含当前节点区间的时候,这时候查询的区间一定在当前区间的子节点中,所以需要向下递归,所以在向下递归之前我们需要下传Lazy标记!
这句话需要好好理解一下。

   void PushDown(int l, int r, STNode* cur) {
       if (cur->add == 0)   // 如果当前节点没有Lazy标记,则不需要下传直接返回
           return;
       int mid = (r - l) / 2 + l;
       // 将父节点的 Lazy 标记下传到
       cur->left->add += cur->add;
       cur->right->add += cur->add;
       cur->left->val += (mid - l + 1) * cur->add;
       cur->right->val += (r - mid) * cur->add;
       cur->add = 0;
   }

然后相应的QueryUpdate函数都要做出相应的修改

   void Update(int left, int right, int val) {
       // left right 为要查询的区间   start end 为当前节点所维护的区间  val为区间修改的值
       function<void(int, int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, int val, STNode* cur) {
           if (left <= start && right >= end) {
               cur->val += (end - start + 1) * val;
               cur->add += val;
               return;
           }
           // 走到这里代表要查询的区间[left,right]一定在cur节点的子节点中! 一定要理解!!!!!
           PushDown(start, end, cur);
           int mid = (end - start) / 2 + start;
           int ansl = 0, ansr = 0;
           if (left > mid) {
               dfs(left, right, mid + 1, end, val, cur->right);
           }
           else if (right <= mid) {
               dfs(left, right, start, mid, val, cur->left);
           }
           else {
               dfs(left, right, mid + 1, end, val, cur->right);
               dfs(left, right, start, mid, val, cur->left);
           }
           PushUp(cur);   // 向上更新父节点
           };
       dfs(left, right, 0, N, val, root);
   }

   int Query(int left, int right) {
       // left right 为要查询的区间   start end 为当前节点所维护的区间
       function<int(int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, STNode* cur) {
           if (left <= start && right >= end) {
               return cur->val;
           }
            // 走到这里代表要查询的区间[left,right]一定在cur节点的子节点中! 一定要理解!!!!!
           PushDown(start, end, cur);
           int mid = (end - start) / 2 + start;
           int ansl = 0, ansr = 0;
           if (left > mid) {
               ansl = dfs(left, right, mid + 1, end, cur->right);
           }
           else if (right <= mid) {
               ansr = dfs(left, right, start, mid, cur->left);
           }
           else {
               ansl = dfs(left, right, mid + 1, end, cur->right);
               ansr = dfs(left, right, start, mid, cur->left);
           }
           PushUp(cur);
           return ansl + ansr;
           };
       return dfs(left, right, 0, N, root);
   }

动态开点

到这里实际上线段树就已经结束了,但是做题的时候如果这样搞理论上应该是无法通过的,因为题目的区间都是很大的,一般题目给的区间都是[1,1000000000],如果我们按照上面的构建线段树的方法,内存应该无法通过。与是这里我们又有了一种新的方法——动态开点

我们设想如果我们要查询的区间落在[1,1000000000]上,如果我们上来就开辟1000000000个节点构成线段树,设想下面这种情况,我们的区间查询和更新都集中在[0,5000000000] 上,而区间[500000001,1000000000]只被查询过一次,那么区间[500000001,1000000000]下面500000000个节点实际上是不用开辟的!
详解线段树_第8张图片

空间上是极大的浪费,我们这里的处理的方法和Lazy标记一样:就是当我们用到该节点的时候,我们再去创建该节点。
而我们何时知道需要访问到子节点?
也就是查询区间不包含当前节点区间的时候,意味着目标节点肯定在子节点中,这时候我们再去开辟子节点
我们把这部分逻辑放在PushDown函数里面去实现。


    void PushDown(int l, int r, STNode* cur) {
    	// 创建子节点 如果没有的话
        if (cur->left == nullptr) cur->left = new STNode{ nullptr,nullptr,0,0 };
        if (cur->right == nullptr) cur->right = new STNode{ nullptr,nullptr,0,0 };

        if (cur->add == 0)
            return;
        int mid = (r - l) / 2 + l;
        cur->left->add += cur->add;
        cur->right->add += cur->add;
        cur->left->val += (mid - l + 1) * cur->add;
        cur->right->val += (r - mid) * cur->add;
        cur->add = 0;
    }

完整模板

到这里线段树的模板就已经写完了

struct STNode {

    STNode* left;
    STNode* right;
    int val = 0;
    int add;
};



class SegmentTree {
public:

    void PushUp(STNode* cur) {
        cur->val = cur->left->val + cur->right->val;
    }

    SegmentTree() {
        N = 1000000000;
        root = new STNode{ nullptr,nullptr,0,0 };
    }


    void PushDown(int l, int r, STNode* cur) {
        if (cur->left == nullptr) cur->left = new STNode{ nullptr,nullptr,0,0 };
        if (cur->right == nullptr) cur->right = new STNode{ nullptr,nullptr,0,0 };

        if (cur->add == 0)
            return;
        int mid = (r - l) / 2 + l;
        cur->left->add += cur->add;
        cur->right->add += cur->add;
        cur->left->val += (mid - l + 1) * cur->add;
        cur->right->val += (r - mid) * cur->add;
        cur->add = 0;
    }


    void Update(int left, int right, int val) {
        // left right 为要查询的区间   start end 为当前节点所维护的区间  val为区间修改的值
        function<void(int, int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, int val, STNode* cur) {
            if (left <= start && right >= end) {
                cur->val += (end - start + 1) * val;
                cur->add += val;
                return;
            }
            PushDown(start, end, cur);
            int mid = (end - start) / 2 + start;
            int ansl = 0, ansr = 0;
            if (left > mid) {
                dfs(left, right, mid + 1, end, val, cur->right);
            }
            else if (right <= mid) {
                dfs(left, right, start, mid, val, cur->left);
            }
            else {
                dfs(left, right, mid + 1, end, val, cur->right);
                dfs(left, right, start, mid, val, cur->left);
            }
            PushUp(cur);
            };
        dfs(left, right, 0, N, val, root);
    }

    int Query(int left, int right) {
        // left right 为要查询的区间   start end 为当前节点所维护的区间
        function<int(int, int, int, int, STNode*)> dfs = [&](int left, int right, int start, int end, STNode* cur) {
            if (left <= start && right >= end) {
                return cur->val;
            }
            PushDown(start, end, cur);
            int mid = (end - start) / 2 + start;
            int ansl = 0, ansr = 0;
            if (left > mid) {
                ansl = dfs(left, right, mid + 1, end, cur->right);
            }
            else if (right <= mid) {
                ansr = dfs(left, right, start, mid, cur->left);
            }
            else {
                ansl = dfs(left, right, mid + 1, end, cur->right);
                ansr = dfs(left, right, start, mid, cur->left);
            }
            PushUp(cur);
            return ansl + ansr;
            };
        return dfs(left, right, 0, N, root);
    }


private:
    STNode* root = nullptr;   // 根节点
    int N;                    // 用于记录线段树查询区间的大小
};

总结

但是在做题的时候还是会对模板进行修改,修改的点主要在区间val的意义上,我们这里模板里面val代表的是区间和,但后面题目里面可能是 区间最大值 、区间最小值…,但是万变不离其宗,线段树的结构是不会变的。

例题:

我的日程安排表 1

我的日程安排表 II

我的日程安排表 III

715. Range 模块

统计区间中的整数数目

题目的详情 点击上面链接☝
详解线段树_第9张图片
这道题看完题目描述很明显是一道线段树的题目,题目要求我们每次返回被标记的节点的数量(一个节点不能被重复标记),我们可以统计区间元素的个数来解决这个问题, 但是我们很快发现了一个问题: 如果我们标记了一个区间的节点,但是下次需要标记的区间与该区间重合,如果只是简单的统计区间的数量 那么就会造成重叠区间元素被重复统计!

所以我们这题的解决思路是:

  • 我们还是统计区间的和,但是当查询区间包含节点区间的时候,这时候会出现三种情况:
    • 这个节点区间 从来没有被标记过,我们直接将节点的val设置为区间长度既可(代表这个区间被标记了),并设置Lazy标记
    • 这个节点区间 一部分被标记过,比方说 当前节点区间为[0,5]但是 前面标记过[0,2]区间,这时节点的val=2 ,但是我们当前需要标记[0,5],这时我们需要 将val设置为当前区间的长度 即另val=6,并设置Lazy标记
    • 这个区间 已经被标记过了,因为不能重复标记这个区间,所以不做处理
  • 当我们在PushDown函数里面处理Lazy标记的时候,只要Lazy标记不为0,就把当前区间的val置为节点区间的长度。

JAVA版本


class STNode{
    STNode(STNode l,STNode r,int val,int add){
        this.left=l;
        this.right=r;
        this.val=val;
        this.add=add;
    }
    STNode left;
    STNode right;
    int val=0;
    int add=0;

}
class CountIntervals {
    STNode root;
    int N=1000000000;
    public CountIntervals() {
        root=new STNode(null,null,0,0);
    }
    private void PushDown(STNode cur,int begin,int end){

        if(cur.left==null)
            cur.left=new STNode(null,null,0,0);
        if(cur.right==null)
            cur.right=new STNode(null,null,0,0);

        if(cur.add==0)
            return;

        int mid=(end-begin)/2+begin;

        cur.left.val=(mid-begin+1);  // Lazy标记向下传递到时候 ,只需要把val置为区间长度既可
        cur.right.val=(end-mid);
        cur.left.add++;
        cur.right.add++;
        cur.add=0;
    }
     private void PushUp(STNode cur){
        cur.val=cur.left.val+cur.right.val;
    }
    void Update(int left,int right,int begin,int end,STNode cur){
        if(left<=begin && right>=end){
            if(cur.val!=(end-begin+1)){
                cur.val=(end-begin+1);
                cur.add++;
            }
            return ;
        }

        int mid=(end-begin)/2+begin;

        PushDown(cur,begin,end);
        if(left>mid){
            Update(left,right,mid+1,end,cur.right);
        }else if(right<=mid){
            Update(left,right,begin,mid,cur.left);
        }else{
            Update(left,right,mid+1,end,cur.right);
            Update(left,right,begin,mid,cur.left);
        }
        PushUp(cur);
    }
    public void add(int left, int right) {
        Update(left,right,0,N,root);
    }
    public int count() {
        PushDown(root,0,N);
        PushUp(root);
        return root.val;
    }
}

C++版本



struct STNode {
    STNode* left;
    STNode* right;
    int val;
    int add;
};
class CountIntervals {
public:
    STNode* root = nullptr;

    void PushUp(STNode* cur) {
        cur->val = cur->left->val + cur->right->val;
    }
    void PushDown(STNode* cur, int begin, int end) {
        if (cur->left == nullptr)
            cur->left = new STNode{ nullptr,nullptr,0,0 };
        if (cur->right == nullptr)
            cur->right = new STNode{ nullptr,nullptr,0,0 };
        if (cur->add == 0)
            return;

        int mid = (end - begin) / 2 + begin;
        cur->left->val = (mid-begin+1);
        cur->right->val = (end - mid);
        cur->left->add += cur->add;
        cur->right->add += cur->add;
        cur->add = 0;
    }
    void Update(int left, int right, int begin, int end, STNode* cur) {
        if (left <= begin && right >= end) {
            if (cur->val != end - begin + 1) {
                cur->val = (end - begin + 1);
                cur->add += 1;
            }
            return;
        }

        PushDown(cur,begin,end);
        int mid = (end - begin) / 2 + begin;
        if (left > mid)
            Update(left, right, mid + 1, end, cur->right);
        else if (right <= mid)
            Update(left, right, begin, mid, cur->left);
        else {
            Update(left, right, mid + 1, end, cur->right);
            Update(left, right, begin, mid, cur->left);
        }
        PushUp(cur);
    }
    CountIntervals() {
        root = new STNode{ nullptr,nullptr,0,0 };
    }

    void add(int left, int right) {
        Update(left, right, 0, 1000000000, root);
        //cout<
    }

    int count() {
        PushDown(root,0,1000000000);
        root->val = root->left->val + root->right->val;
        return root->val;
    }
};

/**
 * Your CountIntervals object will be instantiated and called as such:
 * CountIntervals* obj = new CountIntervals();
 * obj->add(left,right);
 * int param_2 = obj->count();
 */

你可能感兴趣的:(算法,数据结构)