10、线段树(segment_tree)

1、简介

线段树算法是一种快速查询一段区间内的信息的算法,由于其实现简单,所以广泛应用于程序设计竞赛中。
线段树是一棵完美二叉树,即所有的叶子节点的深度均相同, 并且所有的非叶子节点都有两个子节点。每个节点维护一个区间,这个区间为父节点二分后的子区间,根节点维护整个区间,叶子节点维护单个元素,当元素个数为n时,对区间的操作都可以在O(log n)的时间内完成,因为此时树的深度为log2 n + 1,每次操作只需从叶子节点开始,往上更新至根节点,每层只需更新相关的一个区间即可,操作次数log2 n + 1,即在O(log n)的时间内可完成。

2、可实现的功能

线段树可以提供不同的功能,例如最常见的求区间内的最大最小值和求区间内的和,还有其他类似的功能,实现思路基本相同。

3、线段树的操作

  • 线段树的查询: O(logn)

查询的区间 和 线段树节点区间 相等 -> 直接返回

查询的区间 被 线段树节点区间 包含-> 递归向下搜索左右子树

查询的区间 和 线段树节点区间 不相交 -> 结束

查询的区间 和 线段树节点区间 相交且不相等 -> 分裂查询区间

  • 线段树的建立: O(n)

自上而下递归分裂

自下而上回溯更新

  • 线段树的更新: O(logn)

自上而下递归查询

自下而上回溯更新

SegmentTreePython 实现:

# -*- coding: utf-8 -*-
"""
    Description: 线段树
    构建一个线段树,线段树的非叶子节点存储的是区间内的最大值
"""
import math


class SegmentTree(object):
    def __init__(self, arr):
        self.size = len(arr)
        # approximate the overall size of segment tree with array
        self.st = [0] * (4 * self.size)
        # build segment tree
        self.build(0, 0, self.size - 1, arr)

    def left(self, idx):
        """求索引为 idx 的节点的左子树的索引"""
        return idx * 2 + 1

    def right(self, idx):
        """求索引为 idx 的节点的右子树的索引"""
        return idx * 2 + 2

    def build(self, idx, left_idx, right_idx, arr):
        if left_idx == right_idx:
            self.st[idx] = arr[left_idx]
        else:
            mid = (left_idx + right_idx) // 2
            self.build(self.left(idx), left_idx, mid, arr)
            self.build(self.right(idx), mid + 1, right_idx, arr)
            # 非叶子节点中存储的是区间内的最大值
            self.st[idx] = max(self.st[self.left(idx)], self.st[self.right(idx)])

    def update(self, a, b, value):
        """将[a, b] 区间被的值更新为 value"""
        return self.update_recursive(0, 0, self.size - 1, a, b, value)

    def update_recursive(self, idx, left_idx, right_idx, a, b, value):
        if right_idx < a or left_idx > b:
            return True
        if left_idx == right_idx:
            self.st[idx] = value
            return True
        mid = (left_idx + right_idx) // 2
        self.update_recursive(self.left(idx), left_idx, mid, a, b, value)
        self.update_recursive(self.right(idx), mid + 1, right_idx, a, b, value)
        self.st[idx] = max(self.st[self.left(idx)], self.st[self.right(idx)])
        return True

    def query(self, a, b):
        return self.query_recursive(0, 0, self.size - 1, a, b)

    def query_recursive(self, idx, left_idx, right_idx, a, b):
        if right_idx < a or left_idx > b:
            return -math.inf
        if left_idx >= a and right_idx <= b:
            return self.st[idx]
        mid = (left_idx + right_idx) // 2
        q1 = self.query_recursive(self.left(idx), left_idx, mid, a, b)
        q2 = self.query_recursive(self.right(idx), mid + 1, right_idx, a, b)
        return max(q1, q2)

    def show_data(self):
        show_list = []
        for i in range(0, self.size):
            show_list += [self.query(i, i)]
        print(show_list)


if __name__ == '__main__':
    array = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9]
    print(len(array))
    segment_tree = SegmentTree(array)
    print(segment_tree.query(4, 6))
    print(segment_tree.query(7, 14))
    print(segment_tree.query(-3, 2))
    segment_tree.show_data()

    segment_tree.update(1, 3, 111)
    segment_tree.show_data()

    segment_tree.update(7, 8, 234)
    segment_tree.show_data()

你可能感兴趣的:(Python3,数据结构与算法)