在上次打劫完一条街道之后和一圈房屋后,小偷又发现了一个新的可行窃的地区。这个地区只有一个入口,我们称之为“根”。 除了“根”之外,每栋房子有且只有一个“父“房子与之相连。一番侦察之后,聪明的小偷意识到“这个地方的所有房屋的排列类似于一棵二叉树”。 如果两个直接相连的房子在同一天晚上被打劫,房屋将自动报警。
计算在不触动警报的情况下,小偷一晚能够盗取的最高金额。
示例 1:
输入: [3,2,3,null,3,null,1]
3-
/ \
2 3
\ \
3- 1-
输出: 7
解释: 小偷一晚能够盗取的最高金额 = 3 + 3 + 1 = 7.
示例 2:
输入: [3,4,5,1,3,null,1]
3
/ \
4- 5-
/ \ \
1 3 1
输出: 9
解释: 小偷一晚能够盗取的最高金额 = 4 + 5 = 9.
解题思路
这是之前问题Leetcode 198:打家劫舍(最详细的解法!!!) 和 Leetcode 213:打家劫舍 II(最详细的解法!!!) 的拓展。
思路和之前一样,我们遍历整棵二分搜索树,对于每个节点,我们都需要判断这个节点的值我们需不需要取。例如
3<-
/ \
2 3
\ \
3 1
如果我们取3
,那么3.left
和3.right
我们都不能取了;如果我们不取,那么最大值来自于左右孩子的和,也就是3.left+3.right
,但是对于左右孩子来说又涉及到了上面的问题,我们是取还是不取呢?这样循环递推下去,直到节点为null
,我们直接返回0
即可。
定义函数 f ( r o o t ) f(root) f(root)表示以root
为根节点的最优解,那么:
其中root
的左右孩子为l
和r
,左孩子l
的左右孩子为ll
和lr
,右孩子r
的左右孩子为rl
和rr
。对于每个节点需要分别记录被抢的结果和未被抢的结果。
from functools import lru_cache
class Solution:
def rob(self, root):
"""
:type root: TreeNode
:rtype: int
"""
@lru_cache(None)
def dfs(root):
if not root:
return 0
res = root.val
if root.left:
res += dfs(root.left.left) + dfs(root.left.right)
if root.right:
res += dfs(root.right.left) + dfs(root.right.right)
return max(res, dfs(root.left) + dfs(root.right))
return dfs(root)
我们在上面的代码中使用了lru_cache
加速代码,如果不使用lru_cache
的话会超时。为什么呢?必然有许多重复计算,因为返回值中没有存储当前节点抢还是不抢。
如果我们重新定义函数 f ( r o o t ) f(root) f(root)表示以root
为根节点:1)被抢的最优解 2)不被抢的最优解,此时有两个返回解,最后的结果是这两者的最大值。
class Solution:
def rob(self, root):
"""
:type root: TreeNode
:rtype: int
"""
check = self._rob(root)
return max(check[0], check[1])
def _rob(self, root):
if not root:
return [0, 0]
l, r = self._rob(root.left), self._rob(root.right)
return l[1] + r[1] + root.val, max(l) + max(r)
这个代码的效率比之前的效率高很多。
怎么通过迭代解决这个问题呢?实际上这个问题变成了怎么去遍历一颗二叉树,直接采用bfs
即可。遍历的过程中需要记录当前节点有没有被抢过。
class Solution:
def rob(self, root):
"""
:type root: TreeNode
:rtype: int
"""
stack = [(0, root)]
result = {None: (0, 0)}
while stack:
rob, node = stack.pop()
if not node:
continue
if not rob:
stack.extend([(1, node), (0, node.right), (0, node.left)])
else:
result[node] = (result[node.left][1] + result[node.right][1] + node.val,\
max(result[node.left]) + max(result[node.right]))
return max(result[root])
我将该问题的其他语言版本添加到了我的GitHub Leetcode
如有问题,希望大家指出!!!