https://leetcode.com/problems/range-sum-query-mutable/
线段树的典型题目,参考http://bookshadow.com/weblog/2015/08/13/segment-tree-set-1-sum-of-given-range/
http://bookshadow.com/weblog/2015/11/18/leetcode-range-sum-query-mutable/
class NumArray(object):
def __init__(self, nums):
"""
initialize your data structure here.
:type nums: List[int]
"""
self.nums = nums
self.size = size = len(nums)
#这里计算segment tree高度。因为对于第i层(从1开始),其对应的node个数是2^(i-1)。这里假定最后一层node个数为size,反过来求i。
h = int(math.ceil(math.log(size, 2))) if size else 0
maxSize = 2 ** (h + 1) - 1#这里计算segment tree有多少个node
self.st = [0] * maxSize#这里用st来表示segment tree
if size:
self.initST(0, size - 1, 0)
def update(self, i, val):
"""
:type i: int
:type val: int
:rtype: int
"""
#这里update,只用update ST的一半。
if i < 0 or i >= self.size:
return
diff = val - self.nums[i]
self.nums[i] = val
self.updateST(0, self.size - 1, i, diff, 0)
def sumRange(self, i, j):
"""
sum of elements nums[i..j], inclusive.
:type i: int
:type j: int
:rtype: int
"""
if i < 0 or j < 0 or i >= self.size or j >= self.size:
return 0
return self.sumRangeST(0, self.size - 1, i, j, 0)
def initST(self, ss, se, si):
# [ss, se], si是st数组的index,这个function就是要给st数组赋值
# 这里父节点的值等于两个子节点的和
# 从root开始构造
if ss == se:
self.st[si] = self.nums[ss]
else:
mid = (ss + se) / 2
self.st[si] = self.initST(ss, mid, si * 2 + 1) + \
self.initST(mid + 1, se, si * 2 + 2)
return self.st[si]
def updateST(self, ss, se, i, diff, si):
#这里的i就是要update的nums中的元素的index,所以只要i是属于[ss,se],那么这个节点的sum值就要update, si 是线段树数组st的index,可以看做tree 的root
#这里依然是从root开始update,然后看是进入左子树update还是右子树update。
if i < ss or i > se:
return
# i在[ss,se]中间,所以要update。然后递归,找mid,还是看是进入左子树update还是右子树update。
self.st[si] += diff
if ss != se:#如果没到叶子节点。这里没有udpate st数组的叶子节点
mid = (ss + se) / 2
self.updateST(ss, mid, i, diff, si * 2 + 1)#注意这里不是mid - 1
self.updateST(mid + 1, se, i, diff, si * 2 + 2)
def sumRangeST(self, ss, se, qs, qe, si):# si 看成是root
if qs <= ss and qe >= se:# [ss, se]在[qs, qe]里面,就求得了[qs,qe]一部分区间的sum
return self.st[si]
if se < qs or ss > qe:#比方说q range[3,4], 总的range为[0,4], 那么对于左子树[0,2]就肯定对sum没贡献了。return 0
return 0
mid = (ss + se) / 2
#注意这里[qs, qe]一直没变。而且这里不是mid - 1
return self.sumRangeST(ss, mid, qs, qe, si * 2 + 1) + \
self.sumRangeST(mid + 1, se, qs, qe, si * 2 + 2)
class NumArray(object):
def __init__(self, nums):
"""
initialize your data structure here.
:type nums: List[int]
"""
self.nums = nums
self.size = size = len(nums)
h = int(math.ceil(math.log(self.size, 2))) if self.size else 0
maxSize = 2**(h + 1) -1
self.st = [0]*maxSize
if self.size:
self.initST(0, size - 1, 0)
def initST(self, ss, se, si):
#求ss-se的sum,并把sum给si指向的self.st
if ss == se:
self.st[si] = self.nums[ss]
else:
mid = (ss + se) / 2
a = self.initST(ss, mid, 2*si + 1)
b = self.initST(mid + 1, se, 2*si + 2)
self.st[si] = a + b
return self.st[si]#注意这里是要返回值的。
def update(self, i, val):
"""
:type i: int
:type val: int
:rtype: int
"""
if i < 0 or i >= self.size:
return
else:
diff = val - self.nums[i]
self.nums[i] = val
self.updateST(0, self.size - 1, i, diff, 0)
def updateST(self, ss, se, i, diff, si):
if i < ss or i > se:
return
else:
self.st[si] += diff
if ss != se:#这里会忘掉!!!要注意
mid = (ss + se) / 2
self.updateST(ss, mid, i, diff, 2*si + 1)
self.updateST(mid + 1, se, i, diff, 2*si + 2)
def sumRange(self, i, j):
"""
sum of elements nums[i..j], inclusive.
:type i: int
:type j: int
:rtype: int
"""
if i < 0 or j < 0 or i >= self.size or j >= self.size:#[i,j]边界有一个边界越界都不行
return 0
return self.sumRangeST(0, self.size - 1, i, j, 0)
def sumRangeST(self, ss, se, qs, qe, si):
if ss >= qs and se <=qe:
return self.st[si]
elif se < qs or ss > qe:
return 0
else:
mid = (ss + se)/2
return self.sumRangeST(ss, mid, qs, qe, 2*si+1) + \
self.sumRangeST(mid+1, se, qs, qe, 2*si+2)
# Your NumArray object will be instantiated and called as such:
# numArray = NumArray(nums)
# numArray.sumRange(0, 1)
# numArray.update(1, 10)
# numArray.sumRange(1, 2)
举例说明,对于range [0-9], 这里要求[3-6]的sum,模拟如下