【算法日积月累】18-高级数据结构:线段树

高级数据结构:线段树-1

“线段树”实现了高效的“数组区间查询”与“数组区间更新”

“线段树”(segment tree)又称“区间树”,是一个高级数据结构,应用的对象是“数组”。

先来看 LeetCode 第 303 题和 LeetCode 第 307 题。

LeetCode 第 303 题:区域和检索 - 数组不可变

传送门:303. 区域和检索 - 数组不可变

给定一个整数数组 nums,求出数组从索引 ij (ij) 范围内元素的总和,包含 i, j 两点。

示例:

给定 nums = [-2, 0, 3, -5, 2, -1],求和函数为 sumRange()

sumRange(0, 2) -> 1
sumRange(2, 5) -> -1
sumRange(0, 5) -> -3

说明:

  1. 你可以假设数组不可变。
  2. 会多次调用 sumRange 方法。

思路:我们可以设计一个前缀和数组 cumsum ,在查询的时候,只用 时间复杂度,不过在数组元素有频繁更新的时候,会导致性能下降,即这种方式不适用于 LeetCode 第 307 题。

Python 代码:

class NumArray:

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        self.size = len(nums)
        if self.size > 0:
            self.cumsum = [0 for _ in range(self.size + 1)]
            self.cumsum[1] = nums[0]
            for i in range(2, len(nums) + 1):
                self.cumsum[i] = self.cumsum[i - 1] + nums[i - 1]

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        if self.size > 0:
            return self.cumsum[j + 1] - self.cumsum[i]
        return 0

# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# param_1 = obj.sumRange(i,j)

Java 代码:

public class NumArray {
    // cumsum 实现
    // [1,2,3,4]
    // [1,3,6,10]

    private int[] nums;

    public NumArray(int[] nums) {
        this.nums = nums;
        for (int i = 1; i < nums.length; i++) {
            nums[i] = nums[i] + nums[i - 1];
        }
    }

    public int sumRange(int i, int j) {
        return nums[j] - (i - 1 < 0 ? 0 : nums[i - 1]);
    }

    public static void main(String[] args) {
        int[] nums = {1, 2, 3, 4};
        NumArray numArray = new NumArray(nums);
        int result = numArray.sumRange(2, 3);
        System.out.println(result);
    }
}

LeetCode 第 307 题:区域和检索 - 数组可修改

传送门:307. 区域和检索 - 数组可修改。

给定一个整数数组 nums,求出数组从索引 ij (ij) 范围内元素的总和,包含 i, j 两点。

update(i, val) 函数可以通过将下标为 i 的数值更新为 val,从而对数列进行修改。

示例:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8

说明:

  1. 数组仅可以在 update 函数下进行修改。
  2. 你可以假设 update 函数与 sumRange 函数的调用次数是均匀分布的。

如果我们不使用任何数据结构,每次求“区间和”,都会遍历这个区间里的所有元素。如果区间里包含的元素很多,并且查询次数很频繁,时间复杂度就接近 。如果我们使用线段树,就可以把时间复杂度降低到 。

这里要注意的是“线段树”解决的区间问题不涉及“添加”与“删除”操作,即“CURD”,我们只负责“U” 和 “R”。

使用“遍历”与使用“线段树”对于“区间更新”与“区间查询”操作的复杂度

遍历 线段树
区间查询
区间更新

说明:由于我们的线段树(区间树)采用平衡二叉树实现,因此 中的对数函数以 为底,即 。

“线段树”可以使用数组表示

以前我们学习过“堆”,并且知道“堆”是一棵“完全二叉树”,因此“堆”可以用数组表示,基于此,我们很自然地想到可以用数组表示“线段树”。

完全二叉树:除了最后一层以外,其余各层的结点数达到最大,并且最后一层所有的结点都连续地、集中地存储在最左边。

线段树虽然不是完全二叉树,但线段树是平衡二叉树,依然也可以用数组表示。

“自顶向下”递归构建线段树

首先看看“线段树”长什么样。

线段树是一种二叉树结构,不过在实现的时候,可以使用数组实现,这一点和优先队列是一致的。

高级数据结构:线段树-2

需要多少空间

“线段树”的一个经典实现是从上到下递归构建,这一点很像根据员工人数来定领导的人数,设置多少领导的个数就要看员工有多少人了。再想一想,我们在开篇对于线段树的介绍,线段树适合支持的操作是“查询”和“更新”,不适用于“添加”和“删除”。

下面以“员工和领导”为例,讲解从上到下逐步构建线段树的步骤:我们首先要解决的问题是“一共要设置多少领导”,我们宁可有一些位置没有人坐,也要让所有的人都坐下,因此我们在做估计的时候只会放大

高级数据结构:线段树-3

比较极端的一种情况:

高级数据结构:线段树-4
高级数据结构:线段树-5
高级数据结构:线段树-6

我们假设员工的人数为 ,我们也可以认为这就是是我们问题的规模,如果 可以表示成 (例如,、、), 是正整数,这种情况下,组织出来的数一定是满二叉树(除叶子结点外的所有结点均有两个子结点)。那么要设置的领导的人数就是 ,于是我们设置 长度的数组就一定可以容纳下这么多领导和员工。

下面考虑一种糟糕的情况,例如我们的员工人数刚刚好是 次方幂多 ,例如 、、,我们的思路很简单,看看可不可以转化成上面那种情况,因为满二叉树一定是完全二叉树,我们就可以使用数组来表示),原则仍然是放大,例如: “ 放大到 ”,“ 放大到 ”, 但是我们不这么做,我们做得再“过分”一点,我们放大到 倍,它一定比大于问题规模 的最小 次方幂还大,此时为了组织成完全二叉树,将问题规模放大到 ,由上面的分析,我们知道还要给领导准备 把椅子,那么总共领导和员工就要准备 把椅子。

线段树是一颗平衡二叉树

线段树是一棵平衡二叉树(最大深度和最小深度的差距最多为 )。平衡二叉树不会像二分搜索树那样变成一个链表,并且平衡二叉树也可以用数组来表示。

我们还要清楚一点,我们上面只是为了分析出,我们要处理问题规模为 的问题的时候,要准备多少空间,我们分析出当员工数为 的时候,最多分配到 把椅子就能把领导和员工都装下了。下面展示一些图来表示这些情况,特别注意,我们分析的时候是从下到上的,但是实际上,我们拿到问题规模以后的划分却是从上到下的。我们的确浪费了一些空间,甚至有的时候我们浪费了很多空间。

根据上面的讨论,我们可以写出线段树的框架:

Python 代码:

class SegmentTree:

    def __init__(self, arr):
        self.data = arr
        # 开 4 倍大小的空间
        self.tree = [None for _ in range(4 * len(arr))]

    def get_size(self):
        return len(self.data)

    def get(self, index):
        if index < 0 or index >= len(self.data):
            raise Exception("Index is illegal.")
        return self.data[index]

    def __left_child(self, index):
        return 2 * index + 1

    def __right_child(self, index):
        return 2 * index + 2

Java 代码:

public class SegmentTree {
    // 一共要给领导和员工准备的椅子,是我们要构建的辅助数据结构
    private E[] tree;
    // 原始的领导和员工数据,这是一个副本
    private E[] data;

    public SegmentTree(E[] arr) {
        this.data = data;
        // 数组初始化
        data = (E[]) new Object[arr.length];
        for (int i = 0; i < arr.length; i++) {
            data[i] = arr[i];
        }
        tree = (E[]) new Object[4 * arr.length];
    }

    public int getSize() {
        return data.length;
    }

    public E get(int index) {
        if (index < 0 || index >= data.length) {
            throw new IllegalArgumentException("Index is illegal.");
        }
        return data[index];
    }

    /**
     * 返回完全二叉树的数组表示中,索引所表示的元素的左孩子结点的索引
     * 注意:索引编号从 0 开始
     *
     * @param 线段树的某个结点的索引
     * @return 传入的结点的左结点的索引
     */
    public int leftChild(int index) {
        return 2 * index + 1;
    }

    /**
     * 返回完全二叉树的数组表示中,索引所表示的元素的左孩子结点的索引
     * 注意:索引编号从 0 开始
     *
     * @param 线段树的某个结点的索引
     * @return 传入的结点的右结点的索引
     */
    public int rightChild(int index) {
        return 2 * index + 2;
    }

}

根据原始数组创建线段树

这一节的目标是:我们把员工的信息输入一棵线段树,让这棵线段树组织出领导架构。即已知 data 数组,要把 tree 数组构建出来。

分析递归结构

高级数据结构:线段树-7

重点体会:二叉树每做一次分支都是“平均地”一分为二进行的。

递归到底的时候,这个区间只有 个元素

设计私有函数,我们需要考虑 个变量:

1、我们要创建的线段树的根结点的索引,这个索引是线段树的索引;

2、对于线段树结点所要表示的 data 数组的区间的左端点是什么;

3、对于线段树结点所要表示的 data 数组的区间的右端点是什么。

Java 代码:

buildSegmentTree(0, 0, arr.length - 1);

Java 代码:关键代码

/**
 * 这个递归方法的描述一定要非常清楚:
 * 画出 tree 树中以 treeIndex 为根的,统计 data 数组中 [l,r] 区间中的元素
 * 这个方法的实现引入了一个 merge 接口,使得外部可以传入一个方法,方法是如何实现的是根据业务而定
 * 核心代码只有几行,这里关键还是在于递归方法
 *
 * @param treeIndex 我们要创建的线段树根结点所在的索引,treeIndex 是 tree 的索引
 * @param l         对于 treeIndex 结点所要表示的 data 区间端点是什么,l 是 data 的索引
 * @param r         对于 treeIndex 结点所要表示的 data 区间端点是什么,r 是 data 的索引
 */
private void buildSegmentTree(int treeIndex, int l, int r) {
    // 考虑递归到底的情况
    if (l == r) {
        // 平衡二叉树叶子结点的赋值就是靠这句话形成的
        tree[treeIndex] = data[l]; // data[r],此时对应叶子结点的情况
        return;// return 不能忘记
    }
    int mid = l + (r - l) / 2;
    int leftChild = leftChild(treeIndex);
    int rightChild = rightChild(treeIndex);
    // 假设左边右边都处理完了以后,再处理自己
    // 这一点基于,高层信息的构建依赖底层信息的构建
    // 这个递归的过程我们可以通过画图来理解
    // 仔细阅读下面的这三行代码,是不是像极了二分搜索树的后序遍历,我们先处理了左右孩子结点,最后处理自己
    buildSegmentTree(leftChild, l, mid);
    buildSegmentTree(rightChild, mid + 1, r);
    
    // 注意:merge 的实现根据业务而定
    tree[treeIndex] = merge.merge(tree[leftChild], tree[rightChild]);
}

Merge 接口的设计,这里使用传入对象的方式实现了方法传递,是 Command 设计模式。

Java 代码:

public interface Merge {
    E merge(E e1, E e2);
}

SegmentTree 覆盖 toString 方法,用于打印线段树表示的数组,以便执行测试用例。

@Override
public String toString() {
    StringBuilder s = new StringBuilder();
    s.append("[");
    for (int i = 0; i < tree.length; i++) {
        if(tree[i] == null){
            s.append("NULL");
        }else{
            s.append(tree[i]);
        }
        s.append(",");
    }
    s.append("]");
    return s.toString();
}

4、测试方法

public class Main {
    public static void main(String[] args) {
        Integer[] nums = {0, -1, 2, 4, 2};
        SegmentTree segmentTree = new SegmentTree(nums, new Merge() {
            @Override
            public Integer merge(Integer e1, Integer e2) {
                return e1 + e2;
            }
        });
        System.out.println(segmentTree);
    }
}

区间查询

通过编写二分搜索树的经验,我们知道,一些递归的写法通常要写一个辅助函数,在这个辅助函数里完成递归调用。那么对于这个问题中,辅助函数的设计就显得很关键了。

// 在一棵子树里做区间查询,dataL 和 dataR 都是原始数组的索引
public E query(int dataL, int dataR) {
    if (dataL < 0 || dataL >= data.length || dataR < 0 || dataR >= data.length || dataL > dataR) {
        throw new IllegalArgumentException("Index is illegal.");
    }
    // data.length - 1 边界不能弄错
    return query(0, 0, data.length - 1, dataL, dataR);
}

在这个辅助函数的实现过程中,可以画一张图来展现一下具体的计算过程。

高级数据结构:线段树-8
高级数据结构:线段树-9

体会下面这个过程:

我们总是自上而下,从根结点开始向下查询,最坏情况下,才会查询到叶子结点。

Java 代码:

// 这是一个递归调用的辅助方法,应该定义成私有方法
private E query(int treeIndex, int l, int r, int dataL, int dataR) {
    if (l == dataL && r == dataR) {
        // 这里一定不要犯晕,看图说话
        return tree[treeIndex];
    }
    int mid = l + (r - l) / 2;
    int leftChildIndex = leftChild(treeIndex);
    int rightChildIndex = rightChild(treeIndex);
    // 画个示意图就能清楚自己的逻辑是怎样的
    if (dataR <= mid) {
        return query(leftChildIndex, l, mid, dataL, dataR);
    }
    if (dataL >= mid + 1) {
        return query(rightChildIndex, mid + 1, r, dataL, dataR);
    }
    // 横跨两边的时候,先算算左边,再算算右边
    E leftResult = query(leftChildIndex, l, mid, dataL, mid);
    E rightResult = query(rightChildIndex, mid + 1, r, mid + 1, dataR);
    return merge.merge(leftResult, rightResult);
}

LeetCode 第 303 题:区域和检索 - 数组不可变

传送门:303. 区域和检索 - 数组不可变

给定一个整数数组 nums,求出数组从索引 ij (ij) 范围内元素的总和,包含 i, j 两点。

示例:

给定 nums = [-2, 0, 3, -5, 2, -1],求和函数为 sumRange()

sumRange(0, 2) -> 1
sumRange(2, 5) -> -1
sumRange(0, 5) -> -3

说明:

  1. 你可以假设数组不可变。
  2. 会多次调用 sumRange 方法。

思路2:基于线段树(区间树)的实现。

Python 代码:

class NumArray:
    class SegmentTree:

        def __init__(self, arr, merge):
            self.data = arr
            # 开 4 倍大小的空间
            self.tree = [None for _ in range(4 * len(arr))]
            if not hasattr(merge, '__call__'):
                raise Exception('不是函数对象')
            self.merge = merge
            self.__build_segment_tree(0, 0, len(self.data) - 1)

        def get_size(self):
            return len(self.data)

        def get(self, index):
            if index < 0 or index >= len(self.data):
                raise Exception("Index is illegal.")
            return self.data[index]

        def __left_child(self, index):
            return 2 * index + 1

        def __right_child(self, index):
            return 2 * index + 2

        def __build_segment_tree(self, tree_index, data_l, data_r):
            # 区间只有 1 个数的时候,线段树的值,就是数组的值,不必做融合
            if data_l == data_r:
                self.tree[tree_index] = self.data[data_l]
                # 不要忘记 return
                return

            # 然后一分为二去构建
            mid = data_l + (data_r - data_l) // 2
            left_child = self.__left_child(tree_index)
            right_child = self.__right_child(tree_index)

            self.__build_segment_tree(left_child, data_l, mid)
            self.__build_segment_tree(right_child, mid + 1, data_r)

            # 左右都构建好以后,再构建自己,因此是后续遍历
            self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])

        def __str__(self):
            # 打印线段树
            return str([str(ele) for ele in self.tree])

        def query(self, data_l, data_r):
            if data_l < 0 or data_l >= len(self.data) or data_r < 0 or data_r >= len(self.data) or data_l > data_r:
                raise Exception('Index is illegal.')
            return self.__query(0, 0, len(self.data) - 1, data_l, data_r)

        def __query(self, tree_index, tree_l, tree_r, data_l, data_r):
            # 一般而言,线段树区间肯定会大一些,所以会递归查询下去
            # 如果要查询的线段树区间和数据区间完全吻合,把当前线段树索引的返回回去就可以了
            if tree_l == data_l and tree_r == data_r:
                return self.tree[tree_index]

            mid = tree_l + (tree_r - tree_l) // 2
            # 线段树的左右两个索引
            left_child = self.__left_child(tree_index)
            right_child = self.__right_child(tree_index)

            # 因为构建时是这样
            # self.__build_segment_tree(left_child, data_l, mid)
            # 所以,如果右边区间不大于中间索引,就只须要在左子树查询就可以了
            if data_r <= mid:
                return self.__query(left_child, tree_l, mid, data_l, data_r)
            # 同理,如果左边区间 >= mid + 1,就只用在右边区间找就好了
            # self.__build_segment_tree(right_child, mid + 1, data_r)
            if data_l >= mid + 1:
                return self.__query(right_child, mid + 1, tree_r, data_l, data_r)
            # 横跨两边的时候,先算算左边,再算算右边
            left_res = self.__query(left_child, tree_l, mid, data_l, mid)
            right_res = self.__query(right_child, mid + 1, tree_r, mid + 1, data_r)
            return self.merge(left_res, right_res)

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        if len(nums) > 0:
            self.st = NumArray.SegmentTree(nums, lambda a, b: a + b)

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        if self.st is None:
            return 0
        return self.st.query(i, j)

# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# param_1 = obj.sumRange(i,j)

Java 代码:可以点击 这里 查看。

区间更新

想一想更新的步骤,根据画图分析。从树的根开始更新,先把数据更新了,再更新 tree。set方法 的设计与实现,其实是程式化的,这个过程熟悉了以后写起来,就会比较自然。最后不要忘记 merge 一下,从叶子结点开始,父辈结点,祖辈结点,直到根结点都要更新。

Java 代码:

public void set(int dataIndex, E val) {
    if (dataIndex < 0 || dataIndex >= data.length) {
        throw new IllegalArgumentException("Index is illegal.");
    }
    data[dataIndex] = val;
    set(0, 0, data.length - 1, dataIndex, val);
}

Java 代码:

private void set(int treeIndex, int l, int r, int dataIndex, E val) {
    if (l == r) {
        // 来到平衡二叉树的叶子点,这一步是最底层的更新操作
        tree[treeIndex] = val;
        return;
    }
    // 更新祖辈结点,还是先更新左边孩子和右边孩子,再更新
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    int mid = l + (r - l) / 2;
    if (dataIndex >= mid + 1) {
        // 到右边更新
        set(rightTreeIndex, mid + 1, r, dataIndex, val);
    }
    if (dataIndex <= mid) {
        // 到左边更新
        set(leftTreeIndex, l, mid, dataIndex, val);
    }
    tree[treeIndex] = merge.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}

LeetCode 上第 307 号问题:区域和检索 - 数组可修改

传送门:307. 区域和检索 - 数组可修改。

给定一个整数数组 nums,求出数组从索引 ij (ij) 范围内元素的总和,包含 i, j 两点。

update(i, val) 函数可以通过将下标为 i 的数值更新为 val,从而对数列进行修改。

示例:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8

说明:

  1. 数组仅可以在 update 函数下进行修改。
  2. 你可以假设 update 函数与 sumRange 函数的调用次数是均匀分布的。

思路1:基于 cumsum 数组的写法,效率不高。

高级数据结构:线段树-10

说明:这道题如果采用 cumsum 数组的实现,会得到一个 TLE 的结果。但是采用线段树的实现,就能很容易通过。多看几遍,就明白是怎么回事了。

Python 代码:

class NumArray:
    class SegmentTree:

        def __init__(self, arr, merge):
            self.data = arr
            # 开 4 倍大小的空间
            self.tree = [None for _ in range(4 * len(arr))]
            if not hasattr(merge, '__call__'):
                raise Exception('不是函数对象')
            self.merge = merge
            self.__build_segment_tree(0, 0, len(self.data) - 1)

        def get_size(self):
            return len(self.data)

        def get(self, index):
            if index < 0 or index >= len(self.data):
                raise Exception("Index is illegal.")
            return self.data[index]

        def __left_child(self, index):
            return 2 * index + 1

        def __right_child(self, index):
            return 2 * index + 2

        def __build_segment_tree(self, tree_index, data_l, data_r):
            # 区间只有 1 个数的时候,线段树的值,就是数组的值,不必做融合
            if data_l == data_r:
                self.tree[tree_index] = self.data[data_l]
                # 不要忘记 return
                return

            # 然后一分为二去构建
            mid = data_l + (data_r - data_l) // 2
            left_child = self.__left_child(tree_index)
            right_child = self.__right_child(tree_index)

            self.__build_segment_tree(left_child, data_l, mid)
            self.__build_segment_tree(right_child, mid + 1, data_r)

            # 左右都构建好以后,再构建自己,因此是后续遍历
            self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])

        def __str__(self):
            # 打印线段树
            return str([str(ele) for ele in self.tree])

        def query(self, data_l, data_r):
            if data_l < 0 or data_l >= len(self.data) or data_r < 0 or data_r >= len(self.data) or data_l > data_r:
                raise Exception('Index is illegal.')
            return self.__query(0, 0, len(self.data) - 1, data_l, data_r)

        def __query(self, tree_index, tree_l, tree_r, data_l, data_r):
            # 一般而言,线段树区间肯定会大一些,所以会递归查询下去
            # 如果要查询的线段树区间和数据区间完全吻合,把当前线段树索引的返回回去就可以了
            if tree_l == data_l and tree_r == data_r:
                return self.tree[tree_index]

            mid = tree_l + (tree_r - tree_l) // 2
            # 线段树的左右两个索引
            left_child = self.__left_child(tree_index)
            right_child = self.__right_child(tree_index)

            # 因为构建时是这样
            # self.__build_segment_tree(left_child, data_l, mid)
            # 所以,如果右边区间不大于中间索引,就只须要在左子树查询就可以了
            if data_r <= mid:
                return self.__query(left_child, tree_l, mid, data_l, data_r)
            # 同理,如果左边区间 >= mid + 1,就只用在右边区间找就好了
            # self.__build_segment_tree(right_child, mid + 1, data_r)
            if data_l >= mid + 1:
                return self.__query(right_child, mid + 1, tree_r, data_l, data_r)
            # 横跨两边的时候,先算算左边,再算算右边
            left_res = self.__query(left_child, tree_l, mid, data_l, mid)
            right_res = self.__query(right_child, mid + 1, tree_r, mid + 1, data_r)
            return self.merge(left_res, right_res)

        def set(self, data_index, val):
            if data_index < 0 or data_index >= len(self.data):
                raise Exception('Index is illegal.')
            # 先把数据更新了
            self.data[data_index] = val
            # 线段树的更新递归去完成
            self.__set(0, 0, len(self.data) - 1, data_index, val)

        def __set(self, tree_index, tree_l, tree_r, data_index, val):
            if tree_l == tree_r:
                # 注意:这里不能填 tree_l 或者 tree_r
                self.tree[tree_index] = val
                return

            left_child = self.__left_child(tree_index)
            right_child = self.__right_child(tree_index)
            mid = tree_l + (tree_r - tree_l) // 2

            if data_index >= mid + 1:
                # 如果在右边,就只去右边更新
                self.__set(right_child, mid + 1, tree_r, data_index, val)
            if data_index <= mid:
                # 如果在左边,就只去左边更新
                self.__set(left_child, tree_l, mid, data_index, val)
            # 左边右边都更新完以后,再更新自己
            self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])

    def __init__(self, nums):
        """
        :type nums: List[int]
        """

        self.size = len(nums)
        if self.size:
            self.st = NumArray.SegmentTree(nums, lambda a, b: a + b)

    def update(self, i, val):
        """
        :type i: int
        :type val: int
        :rtype: void
        """
        if self.size:
            self.st.set(i, val)

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        if self.size:
            return self.st.query(i, j)

# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(i,val)
# param_2 = obj.sumRange(i,j)

Java 代码:可以点击 这里 查看。

“自底向上”的线段树实现

自底向上的线段树实现只要使用 倍原始数组大小的辅助空间就。下面的 2 张图就展示了这个过程:

我们根据结点个数的奇偶性,分别讨论,但是,最终我们发现,奇偶性并不影响结论。

我们从下到上构建二叉树:

1、先把原始结点做一个偏移,偏移量是原始数组的长度;

2、“自底向上”构建,即父节点就是该结点的索引值除以 ,这个除法是整数除法;

我们发现,不论是奇数个结点还是偶数个结点最终都可以达到根结点,并且根结点的索引是 ,索引是 的位置我们不用。

高级数据结构:线段树-11
高级数据结构:线段树-12

规律(不论结点个数是奇数还是偶数都成立):父结点的索引如果是 i,子结点的索引就是 2 * i2 * i + 1

Python 代码:

class SegmentTree:

    # 自底向上的线段树实现

    def __init__(self, arr, merge):
        self.data = arr
        self.size = len(arr)
        # 开 2 倍大小的空间
        self.tree = [None for _ in range(2 * self.size)]
        if not hasattr(merge, '__call__'):
            raise Exception('不是函数对象')
        self.merge = merge

        # 原始数值赋值
        for i in range(self.size, 2 * self.size):
            self.tree[i] = self.data[i - self.size]
        # 从后向前赋值

        for i in range(self.size - 1, 0, -1):
            self.tree[i] = self.merge(self.tree[2 * i], self.tree[2 * i + 1])

    def get_size(self):
        return len(self.data)

    def query(self, l, r):
        l += self.size
        r += self.size

        res = 0
        while l <= r:
            # 如果左端点是奇数
            if l & 1 == 1:
                if res == 0:
                    # 一开始要加上叶子结点
                    res = self.tree[l]
                else:
                    res = self.merge(res, self.tree[l])
                # 把左端点变成偶数
                l += 1
            if r & 1 == 0:
                if res == 0:
                    # 一开始要加上叶子结点
                    res = self.tree[r]
                else:
                    res = self.merge(res, self.tree[r])
                # 把右端点变成奇数
                r -= 1
            # 往叶子结点上走,所以是除以 2 
            l //= 2
            r //= 2
        return res

    def set(self, i, val):
        i += self.size

        self.tree[i] = val
        while i > 0:
            left = i
            right = i
            if i & 1 == 0:
                right = i + 1
            else:
                left = i - 1
            if left == 0:
                self.tree[i // 2] = self.tree[right]
            else:
                self.tree[i // 2] = self.merge(self.tree[left], self.tree[right])
            i //= 2


if __name__ == '__main__':
    nums = [-2, 0, 3, -5, 2, -1]
    st = SegmentTree(nums, lambda a, b: a + b)

    result1 = st.query(0, 2)
    print(result1)
    result2 = st.query(2, 5)
    print(result2)

    result3 = st.query(0, 5)
    print(result3)

Java 代码:

public class NumArray {

    private SegmentTree segmentTree;

    public NumArray(int[] nums) {
        Merger merger = new Merger() {
            @Override
            public Integer merge(Integer e1, Integer e2) {
                return e1 + e2;
            }
        };
        Integer[] arr = new Integer[nums.length];
        for (int i = 0; i < nums.length; i++) {
            arr[i] = nums[i];
        }
        segmentTree = new SegmentTree(arr, merger);
    }


    public void update(int i, int val) {
        segmentTree.set(i, val);
    }

    public int sumRange(int i, int j) {
        return segmentTree.query(i, j);
    }

    private interface Merger {
        E merge(E e1, E e2);
    }

    private class SegmentTree {

        private E[] tree;
        private int len;
        private Merger merger;

        private SegmentTree(E[] arr, Merger merger) {
            this.merger = merger;
            len = arr.length;
            tree = (E[]) new Object[len * 2];
            for (int i = len; i < 2 * len; i++) {
                tree[i] = arr[i - len];
            }
            for (int i = len - 1; i > 0; i--) {
                tree[i] = merger.merge(tree[2 * i], tree[2 * i + 1]);
            }
        }


        public E query(int l, int r) {
            l += len;
            r += len;
            E res = null;
            while (l <= r) {
                if (l % 2 == 1) {
                    if (res == null) {
                        res = tree[l];
                    } else {
                        res = merger.merge(res, tree[l]);
                    }
                    l++;
                }
                if (r % 2 == 0) {
                    if (res == null) {
                        res = tree[r];
                    } else {
                        res = merger.merge(res, tree[r]);
                    }
                    r--;
                }
                l /= 2;
                r /= 2;
            }
            return res;
        }

        public void set(int i, E val) {
            i += len;
            tree[i] = val;
            while (i > 0) {
                int left = i;
                int right = i;
                // i 是左边结点
                if (i % 2 == 0) {
                    right = i + 1;
                } else {
                    left = i - 1;
                }
                if (left == 0) {
                    tree[i / 2] = tree[right];
                } else {
                    tree[i / 2] = merger.merge(tree[left], tree[right]);

                }
                i /= 2;
            }
        }
    }
}

本文源代码

Python:代码文件夹,Java:代码文件夹。

参考资料

1、B 站上一位 UP 主的讲解:线段树入门。

博客地址:https://wmathor.com

(本节完)

你可能感兴趣的:(【算法日积月累】18-高级数据结构:线段树)