数据结构 - 线段树的运用

数据结构 - 线段树的运用

  • 前言
  • 一. 线段树的运用
    • 1.1 区间和 - 线段树节点的成员变量
    • 1.2 线段树的构建
    • 1.3 线段树的区间和查询
    • 1.4 线段树的区间和更新
    • 1.5 完整代码
  • 二. 线段树的动态扩建
    • 2.1 向下递推
    • 2.2 向上递推
    • 2.3 更新操作
    • 2.4 查询操作
    • 2.5 完整代码
  • 三. 线段树的使用案例
    • 3.1 定长线段树的区间和计算
    • 3.2 动态线段树的区间和计算

前言

想要精通算法和SQL的成长之路 - 系列导航

一. 线段树的运用

我们先来看下线段树的含义:

  1. 线段树(Segment Tree):是一种解决 区间查询问题 的数据结构。
  2. 它将一个区间划分成多个较小的子区间,并对每个子区间计算出一些有用的信息,通常是该子区间的统计值(例如最大值、最小值、总和等)
  3. 通过将这些子区间的信息进行合并,线段树可以高效地回答各种区间查询问题。

线段树通常用于解决以下类型的问题:

  • 区间最值查询:找到给定区间内的最大值、最小值等。
  • 区间更新:对给定区间内的元素进行更新。
  • 区间求和:计算给定区间内元素的总和。

那么线段树的构建过程,通常是一个 分治递归 的过程。线段树的时间复杂度为O(logN)

例如我们有个数组:[1,2,3,4,5],它的区间和用线段树表示就是:
数据结构 - 线段树的运用_第1张图片

1.1 区间和 - 线段树节点的成员变量

首先我们思考一下,从上图中,我们一个节点需要包括哪些数据:

  • 左右节点。
  • 当前节点的左右区间。
  • 当前区间和。

那么不难得出,代码结构如下:

public class SegmentTreeNode {
    public SegmentTreeNode left;
    public SegmentTreeNode right;
    public int start;
    public int end;
    public int val;

    public SegmentTreeNode(int start, int end) {
        this.start = start;
        this.end = end;
    }
}

1.2 线段树的构建

我们往往针对的是一个数组,去构建它的线段树:

public class SegmentTree {
    // 线段树的根节点
    private SegmentTreeNode root;

    public SegmentTree(int[] nums) {
        root = buildTree(nums, 0, nums.length - 1);
    }

     private SegmentTreeNode buildTree(int[] nums, int start, int end) {
        if (start > end) {
            return null;
        }
        // 构建当前节点
        SegmentTreeNode node = new SegmentTreeNode(start, end);
        // 如果区间长度为1,那么当前节点的sum就是区间内的值
        if (start == end) {
            node.val = nums[start];
        } else {
            // 开始递归构建左右子树
            int mid = (start + end) >> 1;
            node.left = buildTree(nums, start, mid);
            node.right = buildTree(nums, mid + 1, end);
            // 当前节点的sum等于左右子树的sum之和
            node.val = node.left.val + node.right.val;
        }
        return node;
    }
}

1.3 线段树的区间和查询

构建完之后,我们需要计算区间和了:传入指定的区间queryStartqueryEnd,返回 [queryStart,queryEnd] 区间内的总和。

public int query(int queryStart, int queryEnd) {
    return queryHelper(root, queryStart, queryEnd);
}

private int queryHelper(SegmentTreeNode node, int queryStart, int queryEnd) {
    // 如果当前节点为空,或者当前节点的区间和不在查询区间内,那么返回0
    if (node == null || queryStart > node.end || queryStart < node.start) {
        return 0;
    }
    // 如果当前节点的区间完全在查询区间内,那么直接返回当前节点的sum。
    // 例如我们要查询[2,4] 的区间和,[2, 4] = [2, 2] + [3, 4]. 当前节点的区间是 [2,2] 或者 [3,4],所以直接返回当前节点的sum即可。
    if (node.start >= queryStart && node.end <= queryEnd) {
        return node.val;
    }
    // 否则,我们需要递归查询左右子树
    int mid = (node.start + node.end) >> 1;
    // 注意这里的Math.min和Math.max,因为我们的查询区间是[queryStart, queryEnd],而当前节点的区间是[node.start, node.end],所以我们需要取交集
    int leftSum = queryHelper(node.left, queryStart, Math.min(queryEnd, mid));
    int rightSum = queryHelper(node.right, Math.max(queryStart, mid + 1), queryEnd);
    return leftSum
}

1.4 线段树的区间和更新

上面的区间和计算,往往是基于数组的值不变的情况下进行的。那么假若数组中的某个元素被更新了,那么我们的区间和就不正确了,跟这个元素有关的所有链路,都要被更新,因此我们还需要准备一个更新的函数,它和查询非常类似,同样是递归操作。

public void update(int index, int newVal) {
    updateHelper(root, index, newVal);
}

private void updateHelper(SegmentTreeNode node, int index, int newVal) {
    if (node == null || index < node.start || index > node.end) {
        return;
    }
    // 如果当前节点的区间就是要更新的区间,那么直接更新当前节点的sum即可
    if (node.start == node.end) {
        node.val = newVal;
        return;
    }
    // 否则,我们需要递归更新左右子树
    int mid = (node.start + node.end) >> 1;
    if (index <= mid) {
        updateHelper(node.left, index, newVal);
    } else {
        updateHelper(node.right, index, newVal);
    }
    // 更新完左右子树之后,需要更新当前节点的sum
    node.val = node.left.val + node.right.val;
}

1.5 完整代码

public class SegmentTree {
    // 线段树的根节点
    private SegmentTreeNode root;

    public SegmentTree(int[] nums) {
        root = buildTree(nums, 0, nums.length - 1);
    }

    private SegmentTreeNode buildTree(int[] nums, int start, int end) {
        if (start > end) {
            return null;
        }
        // 构建当前节点
        SegmentTreeNode node = new SegmentTreeNode(start, end);
        // 如果区间长度为1,那么当前节点的sum就是区间内的值
        if (start == end) {
            node.val = nums[start];
        } else {
            // 开始递归构建左右子树
            int mid = (start + end) >> 1;
            node.left = buildTree(nums, start, mid);
            node.right = buildTree(nums, mid + 1, end);
            // 当前节点的sum等于左右子树的sum之和
            node.val = node.left.val + node.right.val;
        }
        return node;
    }

    public int query(int queryStart, int queryEnd) {
        return queryHelper(root, queryStart, queryEnd);
    }

    private int queryHelper(SegmentTreeNode node, int queryStart, int queryEnd) {
        // 如果当前节点为空,或者当前节点的区间和不在查询区间内,那么返回0
        if (node == null || queryStart > node.end || queryStart < node.start) {
            return 0;
        }
        // 如果当前节点的区间完全在查询区间内,那么直接返回当前节点的sum。
        // 例如我们要查询[2,4] 的区间和,[2, 4] = [2, 2] + [3, 4]. 当前节点的区间是 [2,2] 或者 [3,4],所以直接返回当前节点的sum即可。
        if (node.start >= queryStart && node.end <= queryEnd) {
            return node.val;
        }
        // 否则,我们需要递归查询左右子树
        int mid = (node.start + node.end) >> 1;
        // 注意这里的Math.min和Math.max,因为我们的查询区间是[queryStart, queryEnd],而当前节点的区间是[node.start, node.end],所以我们需要取交集
        int leftSum = queryHelper(node.left, queryStart, Math.min(queryEnd, mid));
        int rightSum = queryHelper(node.right, Math.max(queryStart, mid + 1), queryEnd);
        return leftSum + rightSum;
    }

    public void update(int index, int newVal) {
        updateHelper(root, index, newVal);
    }

    private void updateHelper(SegmentTreeNode node, int index, int newVal) {
        if (node == null || index < node.start || index > node.end) {
            return;
        }
        // 如果当前节点的区间就是要更新的区间,那么直接更新当前节点的sum即可
        if (node.start == node.end) {
            node.val = newVal;
            return;
        }
        // 否则,我们需要递归更新左右子树
        int mid = (node.start + node.end) >> 1;
        if (index <= mid) {
            updateHelper(node.left, index, newVal);
        } else {
            updateHelper(node.right, index, newVal);
        }
        // 更新完左右子树之后,需要更新当前节点的sum
        node.val = node.left.val + node.right.val;
    }
}

二. 线段树的动态扩建

第一节的代码,它有一个前提:

  • 我们已知数组的长度和各个元素的值。

那如果我们不知道数组包含的元素个数以及各个元素值的时候,怎么去建立这颗线段树?如果可以,它必定是在不断地更新的基础上去动态扩建的。

还是以求区间和为例,我们试想一下,在动态扩建的过程中,每个新节点需要向其他节点传递什么信息?

  • 本次应该新增的值,我们用一个变量add来标识。
  • 同时,由于数组的范围不再是固定,因此数据结构中应该剔除startend属性。
public static class SegmentTreeNode {
    public SegmentTreeNode left;
    public SegmentTreeNode right;
    public int add;
    public int val;
    public SegmentTreeNode() {}
}

2.1 向下递推

我们定义一个pushDown函数,它的功能有这么几个:

  • 动态创建左右子节点。
  • 给左右子节点传递add值,以及计算他们的区间和val
  • 若传递结束,那么要将当前节点的add值置为0。

考虑到add值的传递,以及树中叶子节点的性质,除了当前节点node我们还需要两个变量来标识当前节点的左孩子数量和右孩子数量。

代码如下:

private void pushDown(SegmentTreeNode node, int leftNum, int rightNum) {
    // 动态开点
    if (node.left == null) {
        node.left = new SegmentTreeNode();
    }
    if (node.right == null) {
        node.right = new SegmentTreeNode();
    }
    // 如果当前节点的add值为0,那么我们不需要更新子节点的add值
    if (node.val == 0) {
        return;
    }
    // 否则,更新左右子节点的add值
    node.left.add += node.add * leftNum;
    node.right.add += node.add * rightNum;
    // 更新左右子节点的add值
    node.left.val += node.add;
    node.right.val += node.add;
    // 更新当前节点的add值
    node.add = 0;
}

2.2 向上递推

我们定义一个函数pushUp,主要用来计算当前节点的区间值:

  • 当前节点的区间值 = 左区间和 + 右区间和。
private void pushUp(SegmentTreeNode node) {
    node.val = node.left.val + node.right.val;
}

2.3 更新操作

/**
* @param node  线段树的根节点
 * @param start 线段树的起始位置
 * @param end   线段树的结束位置
 * @param left  查询区间的左边界
 * @param right 查询区间的右边界
 * @param addValue   比原本的值多加的值
 */
public void update(SegmentTreeNode node, int start, int end, int left, int right, int addValue) {
    // 如果线段树的区间完全在查询区间内,那么直接更新当前节点的add值即可
    if (start >= left && end <= right) {
        // 该区间内,所有叶子节点都要加上val
        node.val += (end - start + 1) * addValue;
        // 该区间内,所有非叶子节点都要加上val,传递给后面的新节点
        node.add += addValue;
        return;
    }
    // 如果不在查询区间内,那么我们需要递归更新左右子树
    int mid = (start + end) >> 1;
    // 向下传递标记
    pushDown(node, mid - start + 1, end - mid);
    if (left <= mid) {
        update(node.left, start, mid, left, right, addValue);
    }
    // [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
    if (right > mid) {
        update(node.right, mid + 1, end, left, right, addValue);
    }
    // 计算当前节点的val值
    pushUp(node);
}

2.4 查询操作

public int query(SegmentTreeNode node, int start, int end, int left, int right) {
    if (left <= start && end <= right) {
        return node.val;
    }
    // 把当前区间 [start, end] 均分得到左右孩子的区间范围
    int mid = (start + end) >> 1, ans = 0;
    // 下推标记
    pushDown(node, mid - start + 1, end - mid);
    // [start, mid] 和 [l, r] 可能有交集,遍历左孩子区间
    if (left <= mid) {
        ans += query(node.left, start, mid, left, right);
    }
    // [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
    if (right > mid) {
        ans += query(node.right, mid + 1, end, left, right);
    }
    return ans;
}

2.5 完整代码

public class SegmentTreeDynamic {
    public static class SegmentTreeNode {
        public SegmentTreeNode left;
        public SegmentTreeNode right;
        public int add;
        public int val;

        public SegmentTreeNode() {
        }
    }

    /**
     * @param node     线段树的根节点
     * @param leftNum  左节点的叶子数量
     * @param rightNum 右节点的叶子数量
     */
    private void pushDown(SegmentTreeNode node, int leftNum, int rightNum) {
        // 动态开点
        if (node.left == null) {
            node.left = new SegmentTreeNode();
        }
        if (node.right == null) {
            node.right = new SegmentTreeNode();
        }
        // 如果当前节点的add值为0,那么我们不需要更新子节点的add值
        if (node.val == 0) {
            return;
        }
        // 否则,更新左右子节点的add值
        node.left.add += node.add * leftNum;
        node.right.add += node.add * rightNum;
        // 更新左右子节点的add值
        node.left.val += node.add;
        node.right.val += node.add;
        // 更新当前节点的add值
        node.add = 0;
    }

    private void pushUp(SegmentTreeNode node) {
        node.val = node.left.val + node.right.val;
    }

    /**
     * @param node  线段树的根节点
     * @param start 线段树的起始位置
     * @param end   线段树的结束位置
     * @param left  查询区间的左边界
     * @param right 查询区间的右边界
     * @param addValue   比原本的值多加的值
     */
    public void update(SegmentTreeNode node, int start, int end, int left, int right, int addValue) {
        // 如果线段树的区间完全在查询区间内,那么直接更新当前节点的add值即可
        if (start >= left && end <= right) {
            // 该区间内,所有叶子节点都要加上val
            node.val += (end - start + 1) * addValue;
            // 该区间内,所有非叶子节点都要加上val,传递给后面的新节点
            node.add += addValue;
            return;
        }
        // 如果不在查询区间内,那么我们需要递归更新左右子树
        int mid = (start + end) >> 1;
        // 向下传递标记
        pushDown(node, mid - start + 1, end - mid);
        if (left <= mid) {
            update(node.left, start, mid, left, right, addValue);
        }
        // [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
        if (right > mid) {
            update(node.right, mid + 1, end, left, right, addValue);
        }
        // 计算当前节点的val值
        pushUp(node);
    }

    public int query(SegmentTreeNode node, int start, int end, int left, int right) {
        if (left <= start && end <= right) {
            return node.val;
        }
        // 把当前区间 [start, end] 均分得到左右孩子的区间范围
        int mid = (start + end) >> 1, ans = 0;
        // 下推标记
        pushDown(node, mid - start + 1, end - mid);
        // [start, mid] 和 [l, r] 可能有交集,遍历左孩子区间
        if (left <= mid) {
            ans += query(node.left, start, mid, left, right);
        }
        // [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
        if (right > mid) {
            ans += query(node.right, mid + 1, end, left, right);
        }
        return ans;
    }
}

三. 线段树的使用案例

Demo如下:数组:[1, 3, 5, 7, 9, 11]

  1. 求得区间[1,4]的区间和。
  2. 如果索引为2的地方的值更新为19,求得区间[1,4]的区间和。

3.1 定长线段树的区间和计算

public static void main(String[] args) {
    int[] nums = {1, 3, 5, 7, 9, 11};
    SegmentTree segmentTree = new SegmentTree(nums);
    int sum = segmentTree.query(1, 4); // 查询区间[1, 4]的和 3+5+7+9=24
    System.out.println(sum); // 输出:24

    segmentTree.update(2, 19); // 将索引2处的值更新为19
    sum = segmentTree.query(1, 4); // 再次查询区间[1, 4]的和
    System.out.println(sum); // 输出:38
}

结果如下:
数据结构 - 线段树的运用_第2张图片

3.2 动态线段树的区间和计算

public static void main(String[] args) {
    int[] nums = {1, 3, 5, 7, 9, 11};
    SegmentTreeDynamic segmentTree = new SegmentTreeDynamic();
    SegmentTreeDynamic.SegmentTreeNode root = new SegmentTreeDynamic.SegmentTreeNode();
    int n = nums.length - 1;
    for (int i = 0; i < nums.length; i++) {
        segmentTree.update(root, 0, n, i, i, nums[i]);
    }

    System.out.println(segmentTree.query(root, 0, n, 1, 4));
    segmentTree.update(root, 0, n, 2, 2, 14);// 在原本值的基础上再多14,相当于定长计算中的5 + 14 = 19
    System.out.println(segmentTree.query(root, 0, n, 1, 4));
}

你可能感兴趣的:(精通算法和SQL之路,数据结构,java,算法)