想一想更新的步骤,根据画图分析。从树的根开始更新,先把数据更新了,再更新 tree。set
方法 的设计与实现,其实是程式化的,这个过程熟悉了以后写起来,就会比较自然。最后不要忘记 merge 一下,从叶子结点开始,父辈结点,祖辈结点,直到根结点都要更新。
Java 代码:
public void set(int dataIndex, E val) {
if (dataIndex < 0 || dataIndex >= data.length) {
throw new IllegalArgumentException("Index is illegal.");
}
data[dataIndex] = val;
set(0, 0, data.length - 1, dataIndex, val);
}
Java 代码:
private void set(int treeIndex, int l, int r, int dataIndex, E val) {
if (l == r) {
// 来到平衡二叉树的叶子点,这一步是最底层的更新操作
tree[treeIndex] = val;
return;
}
// 更新祖辈结点,还是先更新左边孩子和右边孩子,再更新
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (r - l) / 2;
if (dataIndex >= mid + 1) {
// 到右边更新
set(rightTreeIndex, mid + 1, r, dataIndex, val);
}
if (dataIndex <= mid) {
// 到左边更新
set(leftTreeIndex, l, mid, dataIndex, val);
}
tree[treeIndex] = merge.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
给定一个整数数组 nums,求出数组从索引 i 到 j (i ≤ j) 范围内元素的总和,包含 i, j 两点。
update(i, val) 函数可以通过将下标为 i 的数值更新为 val,从而对数列进行修改。
示例:
Given nums = [1, 3, 5]
sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
说明:
方法一::基于前缀和数组的写法,效率不高。
说明:这道题如果采用前缀和数组的实现,会得到一个 TLE 的结果。但是采用线段树的实现,就能很容易通过。多看几遍,就明白是怎么回事了。
Java 代码:
// https://leetcode-cn.com/problems/range-sum-query-mutable/description/
public class NumArray {
private SegmentTree<Integer> segmentTree;
public NumArray(int[] nums) {
if(nums.length!=0){
Integer[] data = new Integer[nums.length];
for (int i = 0; i < nums.length; i++) {
data[i] = nums[i];
}
segmentTree= new SegmentTree<>(data,(a,b)->a+b);
}
}
public void update(int i, int val) {
segmentTree.set(i,val);
}
public int sumRange(int i, int j) {
return segmentTree.query(i,j);
}
private interface Merge<E> {
E merge(E e1, E e2);
}
private class SegmentTree<E> {
private E[] tree;
private E[] data;
private Merge<E> merge;
public SegmentTree(E[] arr, Merge<E> merge) {
this.data = data;
this.merge = merge;
data = (E[]) new Object[arr.length];
for (int i = 0; i < arr.length; i++) {
data[i] = arr[i];
}
tree = (E[]) new Object[4 * arr.length];
buildSegmentTree(0, 0, arr.length - 1);
}
private void buildSegmentTree(int treeIndex, int l, int r) {
if (l == r) {
tree[treeIndex] = data[l];
return;
}
int mid = l + (r - l) / 2;
int leftChild = leftChild(treeIndex);
int rightChild = rightChild(treeIndex);
buildSegmentTree(leftChild, l, mid);
buildSegmentTree(rightChild, mid + 1, r);
tree[treeIndex] = merge.merge(tree[leftChild], tree[rightChild]);
}
public E query(int dataL, int dataR) {
if (dataL < 0 || dataL >= data.length || dataR < 0 || dataR >= data.length || dataL > dataR) {
throw new IllegalArgumentException("Index is illegal.");
}
return query(0, 0, data.length - 1, dataL, dataR);
}
private E query(int treeIndex, int l, int r, int dataL, int dataR) {
if (l == dataL && r == dataR) {
return tree[treeIndex];
}
int mid = l + (r - l) / 2;
int leftChildIndex = leftChild(treeIndex);
int rightChildIndex = rightChild(treeIndex);
if (dataR <= mid) {
return query(leftChildIndex, l, mid, dataL, dataR);
}
if (dataL >= mid + 1) {
return query(rightChildIndex, mid + 1, r, dataL, dataR);
}
E leftResult = query(leftChildIndex, l, mid, dataL, mid);
E rightResult = query(rightChildIndex, mid + 1, r, mid + 1, dataR);
return merge.merge(leftResult, rightResult);
}
public void set(int dataIndex, E val) {
if (dataIndex < 0 || dataIndex >= data.length) {
throw new IllegalArgumentException("Index is illegal.");
}
data[dataIndex] = val;
set(0, 0, data.length - 1, dataIndex, val);
}
private void set(int treeIndex, int l, int r, int dataIndex, E val) {
if (l == r) {
tree[treeIndex] = val;
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (r - l) / 2;
if (dataIndex >= mid + 1) {
set(rightTreeIndex, mid + 1, r, dataIndex, val);
}
if (dataIndex <= mid) {
set(leftTreeIndex, l, mid, dataIndex, val);
}
tree[treeIndex] = merge.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
public int getSize() {
return data.length;
}
public E get(int index) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal.");
}
return data[index];
}
public int leftChild(int index) {
return 2 * index + 1;
}
public int rightChild(int index) {
return 2 * index + 2;
}
}
}
/**
* Your NumArray object will be instantiated and called as such:
* NumArray obj = new NumArray(nums);
* obj.update(i,val);
* int param_2 = obj.sumRange(i,j);
*/
Python 代码:
class NumArray:
class SegmentTree:
def __init__(self, arr, merge):
self.data = arr
# 开 4 倍大小的空间
self.tree = [None for _ in range(4 * len(arr))]
if not hasattr(merge, '__call__'):
raise Exception('不是函数对象')
self.merge = merge
self.__build_segment_tree(0, 0, len(self.data) - 1)
def get_size(self):
return len(self.data)
def get(self, index):
if index < 0 or index >= len(self.data):
raise Exception("Index is illegal.")
return self.data[index]
def __left_child(self, index):
return 2 * index + 1
def __right_child(self, index):
return 2 * index + 2
def __build_segment_tree(self, tree_index, data_l, data_r):
# 区间只有 1 个数的时候,线段树的值,就是数组的值,不必做融合
if data_l == data_r:
self.tree[tree_index] = self.data[data_l]
# 不要忘记 return
return
# 然后一分为二去构建
mid = data_l + (data_r - data_l) // 2
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
self.__build_segment_tree(left_child, data_l, mid)
self.__build_segment_tree(right_child, mid + 1, data_r)
# 左右都构建好以后,再构建自己,因此是后续遍历
self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])
def __str__(self):
# 打印线段树
return str([str(ele) for ele in self.tree])
def query(self, data_l, data_r):
if data_l < 0 or data_l >= len(self.data) or data_r < 0 or data_r >= len(self.data) or data_l > data_r:
raise Exception('Index is illegal.')
return self.__query(0, 0, len(self.data) - 1, data_l, data_r)
def __query(self, tree_index, tree_l, tree_r, data_l, data_r):
# 一般而言,线段树区间肯定会大一些,所以会递归查询下去
# 如果要查询的线段树区间和数据区间完全吻合,把当前线段树索引的返回回去就可以了
if tree_l == data_l and tree_r == data_r:
return self.tree[tree_index]
mid = tree_l + (tree_r - tree_l) // 2
# 线段树的左右两个索引
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
# 因为构建时是这样
# self.__build_segment_tree(left_child, data_l, mid)
# 所以,如果右边区间不大于中间索引,就只须要在左子树查询就可以了
if data_r <= mid:
return self.__query(left_child, tree_l, mid, data_l, data_r)
# 同理,如果左边区间 >= mid + 1,就只用在右边区间找就好了
# self.__build_segment_tree(right_child, mid + 1, data_r)
if data_l >= mid + 1:
return self.__query(right_child, mid + 1, tree_r, data_l, data_r)
# 横跨两边的时候,先算算左边,再算算右边
left_res = self.__query(left_child, tree_l, mid, data_l, mid)
right_res = self.__query(right_child, mid + 1, tree_r, mid + 1, data_r)
return self.merge(left_res, right_res)
def set(self, data_index, val):
if data_index < 0 or data_index >= len(self.data):
raise Exception('Index is illegal.')
# 先把数据更新了
self.data[data_index] = val
# 线段树的更新递归去完成
self.__set(0, 0, len(self.data) - 1, data_index, val)
def __set(self, tree_index, tree_l, tree_r, data_index, val):
if tree_l == tree_r:
# 注意:这里不能填 tree_l 或者 tree_r
self.tree[tree_index] = val
return
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
mid = tree_l + (tree_r - tree_l) // 2
if data_index >= mid + 1:
# 如果在右边,就只去右边更新
self.__set(right_child, mid + 1, tree_r, data_index, val)
if data_index <= mid:
# 如果在左边,就只去左边更新
self.__set(left_child, tree_l, mid, data_index, val)
# 左边右边都更新完以后,再更新自己
self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])
def __init__(self, nums):
"""
:type nums: List[int]
"""
self.size = len(nums)
if self.size:
self.st = NumArray.SegmentTree(nums, lambda a, b: a + b)
def update(self, i, val):
"""
:type i: int
:type val: int
:rtype: void
"""
if self.size:
self.st.set(i, val)
def sumRange(self, i, j):
"""
:type i: int
:type j: int
:rtype: int
"""
if self.size:
return self.st.query(i, j)
# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(i,val)
# param_2 = obj.sumRange(i,j)