[算法]线段树的python实现

线段树的python实现

求区间和,区间最值等

博客

博客1: 线段树 从入门到进阶

博客2:线段树详解

博客3

题目地址

落谷:

P3374 【模板】树状数组 1

P3368 【模板】树状数组 2

基本概念

  1. 什么是线段树
    线段树是一种二叉搜索树,什么叫做二叉搜索树,首先满足二叉树,每个结点度小于等于二,即每个结点最多有两颗子树,何为搜索,我们要知道,线段树的每个结点都存储了一个区间,也可以理解成一个线段,而搜索,就是在这些线段上进行搜索操作得到你想要的答案。

  2. 线段树能够解决什么样的问题
    线段树的适用范围很广,可以在线维护修改以及查询区间上的最值.每次更新以及查询的时间复杂度为O(logN)

  3. 线段树的空间使用
    在优化之前,线段树是有空间没有使用的,因此需要4*n(n为数组大小)个节点来储存数据

  4. 线段树的下标关系
    因为他是一棵完全二叉树,因此满足如下下标关系(下标从1开始)

    1. l = fa*2 (左子树下标为父亲下标的两倍)
    2. r = fa*2+1(右子树下标为父亲下标的两倍+1)

    使用位运算的话是如下的表示关系

    1. k<<1(结点k的左子树下标) : k*2
    2. 2k<<1|(结点k的右子树下标): k*2+1
  5. 线段树的基本操作

    • 单点更新
    • 区间查询
    • 区间更新

python代码

P3374 【模板】树状数组 1

注: 由于要开4倍空间,因此有两个用例内存限制过不了

单点更新和区间查询

理解

区间查询:

1、如果这个区间被完全包括在目标区间里面,直接返回这个区间的值

2、如果这个区间的左儿子和目标区间有交集,那么搜索左儿子,直到满足条件1为止

3、如果这个区间的右儿子和目标区间有交集,那么搜索右儿子,直到满足条件1为止

点单更新:

找到要更新的叶子节点,根据访问路径,由叶子节点向跟节点更新路径上节点的值,

# 线段树的节点类
class TreeNode(object):
    def __init__(self):
        self.left = -1
        self.right = -1
        self.sum_num = 0

    # 打印函数
    def __str__(self):
        return '[%s,%s,%s]' % (self.left, self.right, self.sum_num)

    # 打印函数
    def __repr__(self):
        return '[%s,%s,%s]' % (self.left, self.right, self.sum_num)


# 线段树类
# 以_开头的是递归实现
class Tree(object):
    def __init__(self, n, arr):
        self.n = n
        self.max_size = 4 * n
        self.tree = [TreeNode() for i in range(self.max_size)]  # 维护一个TreeNode数组
        self.arr = arr

    # index从1开始
    def _build(self, index, left, right):
        self.tree[index].left = left
        self.tree[index].right = right
        if left == right:
            self.tree[index].sum_num = self.arr[left - 1]
        else:
            mid = (left + right) // 2
            self._build(index * 2, left, mid)
            self._build(index * 2 + 1, mid + 1, right)
            self.pushup_sum(index)

    # 构建线段树
    def build(self):
        self._build(1, 1, self.n)

    def _update(self, point, val, i, l, r, ):
        if self.tree[i].left == self.tree[i].right:
            self.tree[i].sum_num += val
        else:
            mid = (l + r) // 2
            if point <= mid:
                self._update(point, val, i * 2, l, mid)
            else:
                self._update(point, val, i * 2 + 1, mid + 1, r)
                # 根据左右子树更新当前的值
            self.pushup_sum(i)

    # 单点更新
    # point 要更新的数在数组的下标 val更新的值
    def update(self, point, val, ):
        self._update(point, val, 1, 1, self.n)

    # 求和
    def pushup_sum(self, k):
        self.tree[k].sum_num = self.tree[k * 2].sum_num + self.tree[k * 2 + 1].sum_num

    def _query(self, ql, qr, i, l, r, ):
        if l >= ql and r <= qr:  # 若当前范围包含于要查询的范围
            return self.tree[i].sum_num
        else:
            mid = (l + r) // 2
            res_l = 0
            res_r = 0
            if ql <= mid:  # 左子树最大的值大于了查询范围最小的值-->左子树和需要查询的区间交集非空
                res_l = self._query(ql, qr, i * 2, l, mid, )
            if qr > mid:  # 右子树最小的值小于了查询范围最大的值-->右子树和需要查询的区间交集非空
                res_r = self._query(ql, qr, i * 2 + 1, mid + 1, r, )
            return res_l + res_r

    # 区间查询
    def query(self, ql, qr):
        return self._query(ql, qr, 1, 1, self.n)

    # 深度遍历打印数组
    def _show_arr(self, i):
        if self.tree[i].left == self.tree[i].right and self.tree[i].left != -1:
            print(self.tree[i].sum_num, end=" ")
        if 2 * i < len(self.tree):
            self._show_arr(i * 2)
            self._show_arr(i * 2 + 1)

    # 显示更新后的数组的样子
    def show_arr(self, ):
        self._show_arr(1)

# 落谷测试用例1
def test():
    n = 5  # 1 5 4 2 3
    arr = [1, 5, 4, 2, 3]
    tree = Tree(n, arr)
    tree.build()
    tree.update(1, 3)
    res = tree.query(2, 5)
    print(res)
    tree.update(3, -1)
    tree.update(4, 2)
    res = tree.query(1, 4)
    print(res)


if __name__ == '__main__':
    # test()  样例输出
    line1 = [int(x) for x in input().strip().split(" ")]
    n = line1[0]  # 数字的个数
    m = line1[1]  # 操作的个数
    arr = [int(x) for x in input().strip().split(" ")]
    tree = Tree(n, arr)
    tree.build()
    for i in range(m):
        line = [int(x) for x in input().split(" ")]
        op = line[0]
        if op == 1:
            tree.update(line[1], line[2])
        elif op == 2:
            res = tree.query(line[1], line[2])
            print(res)

区间更新和(单点)区间查询

在原有的基础上,添加一个lazy_tag变量和push_down函数.

  1. 为什么需要Lazy_tag

    因为对于一个区间[L,R]来说,我们可能每次都更新区间中的每个值,那样的话更新的复杂度将会是O(NlogN),很高.

  2. Lazy_tag的原理

    1. 每次更新只更新到更新区间完全覆盖线段树结点区间为止,但这样就会导致被更新结点的子孙结点的区间得不到需要更新的信息,

    2. 所以在被更新结点上打上一个标记,称为lazy-tag,等到下次访问这个结点的子结点时再将这个标记传递给子结点,所以也可以叫延迟标记

    3. 也就是说递归更新的过程,更新到结点区间为目标区间的真子集不再往下更新

      下次若是遇到需要用该节点下面的结点的信息,再去更新该节点下面的结点,所以这样的话使得区间更新的操作和区间查询类似,复杂度为O(logN)

理解

区间查询:

1、如果当前区间被完全覆盖在目标区间里,讲这个区间的sum+=k*(tree[i].r-tree[i].l+1)

2、如果没有完全覆盖,则先下传懒标记(lazy_tag)

3、如果这个区间的左儿子和目标区间有交集,那么搜索左儿子

4、如果这个区间的右儿子和目标区间有交集,那么搜索右儿子

将区间查询的左右区间设为一个值就是单点查询了

区间更新,单点查询: P3368 【模板】树状数组 2

# 线段树的节点类
class TreeNode(object):
    def __init__(self):
        self.left = -1
        self.right = -1
        self.sum_num = 0
        self.lazy_tag = 0

    # 打印函数
    def __str__(self):
        return '[%s,%s,%s,%s]' % (self.left, self.right, self.sum_num, self.lazy_tag)

    # 打印函数
    def __repr__(self):
        return '[%s,%s,%s,%s]' % (self.left, self.right, self.sum_num, self.lazy_tag)


# 线段树类
# 以_开头的是递归实现
class Tree(object):
    def __init__(self, n, arr):
        self.n = n
        self.max_size = 4 * n
        self.tree = [TreeNode() for i in range(self.max_size)]  # 维护一个TreeNode数组
        self.arr = arr

    # index从1开始
    def _build(self, index, left, right):
        self.tree[index].left = left
        self.tree[index].right = right
        if left == right:
            self.tree[index].sum_num = self.arr[left - 1]
        else:
            mid = (left + right) // 2
            self._build(index * 2, left, mid)
            self._build(index * 2 + 1, mid + 1, right)
            self.pushup_sum(index)

    # 构建线段树
    def build(self):
        self._build(1, 1, self.n)

    def _update2(self, ql, qr, val, i, l, r, ):
        mid = (l + r) // 2
        if l >= ql and r <= qr:
            self.tree[i].sum_num += (r - l + 1) * val  # 更新和
            self.tree[i].lazy_tag += val  # 更新懒惰标记
        else:
            self.pushdown_sum(i)
            if mid >= ql:
                self._update2(ql, qr, val, i * 2, l, mid)
            if qr > mid:
                self._update2(ql, qr, val, i * 2 + 1, mid + 1, r)
            self.pushup_sum(i)

    # 区间修改
    def update2(self, ql, qr, val, ):
        self._update2(ql, qr, val, 1, 1, self.n)

    def _query2(self, ql, qr, i, l, r, ):
        if l >= ql and r <= qr:  # 若当前范围包含于要查询的范围
            return self.tree[i].sum_num
        else:
            self.pushdown_sum(i)  # modify
            mid = (l + r) // 2
            res_l = 0
            res_r = 0
            if ql <= mid:  # 左子树最大的值大于了查询范围最小的值-->左子树和需要查询的区间交集非空
                res_l = self._query2(ql, qr, i * 2, l, mid, )
            if qr > mid:  # 右子树最小的值小于了查询范围最大的值-->右子树和需要查询的区间交集非空
                res_r = self._query2(ql, qr, i * 2 + 1, mid + 1, r, )
            return res_l + res_r

    def query2(self, ql, qr):
        return self._query2(ql, qr, 1, 1, self.n)

    # 求和,向上更新
    def pushup_sum(self, k):
        self.tree[k].sum_num = self.tree[k * 2].sum_num + self.tree[k * 2 + 1].sum_num

    # 向下更新lazy_tag
    def pushdown_sum(self, i):
        lazy_tag = self.tree[i].lazy_tag
        if lazy_tag != 0:  # 如果有lazy_tag
            self.tree[i * 2].lazy_tag += lazy_tag  # 左子树加上lazy_tag
            self.tree[i * 2].sum_num += (self.tree[i * 2].right - self.tree[i * 2].left + 1) * lazy_tag  # 左子树更新和
            self.tree[i * 2 + 1].lazy_tag += lazy_tag  # 右子树加上lazy_tag
            self.tree[i * 2 + 1].sum_num += (self.tree[i * 2 + 1].right - self.tree[
                i * 2 + 1].left + 1) * lazy_tag  # 右子树更新和
            self.tree[i].lazy_tag = 0  # 将lazy_tag 归0

    # 深度遍历
    def _show_arr(self, i):
        if self.tree[i].left == self.tree[i].right and self.tree[i].left != -1:
            print(self.tree[i].sum_num, end=" ")
        if 2 * i < len(self.tree):
            self._show_arr(i * 2)
            self._show_arr(i * 2 + 1)

    # 显示更新后的数组的样子
    def show_arr(self, ):
        self._show_arr(1)

    def __str__(self):
        return str(self.tree)

# 落谷测试用例1
def test():
    n = 5  # 1 5 4 2 3
    arr = [1, 5, 4, 2, 3]
    tree = Tree(n, arr)
    tree.build()
    tree.update2(2, 4, 2)
    # # print(tree)
    res = tree.query2(3, 3)
    # print(tree)
    print(res)
    tree.update2(1, 5, -1)
    tree.update2(3, 5, 7)
    res = tree.query2(4, 4)
    print(res)


if __name__ == '__main__':
    # 样例输出
    line1 = [int(x) for x in input().strip().split(" ")]
    n = line1[0]  # 数字的个数
    m = line1[1]  # 操作的个数
    arr = [int(x) for x in input().strip().split(" ")]
    tree = Tree(n, arr)
    tree.build()
    for i in range(m):
        line = [int(x) for x in input().split(" ")]
        op = line[0]
        if op == 1:
            tree.update2(line[1], line[2], line[3])
        elif op == 2:
            res = tree.query2(line[1], line[1])
            print(res)

区间更新,区间查询:洛谷线段树模板1

代码同上,修改下main函数即可

if __name__ == '__main__':
    # 样例输出
    line1 = [int(x) for x in input().strip().split(" ")]
    n = line1[0]  # 数字的个数
    m = line1[1]  # 操作的个数
    arr = [int(x) for x in input().strip().split(" ")]
    tree = Tree(n, arr)
    tree.build()
    for i in range(m):
        line = [int(x) for x in input().split(" ")]
        op = line[0]
        if op == 1:
            tree.update2(line[1], line[2], line[3])
        elif op == 2:
            res = tree.query2(line[1], line[2]) # 修改部分
            print(res)

你可能感兴趣的:(算法)