给定一个无向、连通的树。树中有 N 个标记为 0…N-1 的节点以及 N-1 条边 。
第 i 条边连接节点 edges[i][0] 和 edges[i][1] 。
返回一个表示节点 i 与其他所有节点距离之和的列表 ans。
示例 1:
输入: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释:
如下为给定的树的示意图:
0
/ \
1 2
/|\
3 4 5
我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。
说明: 1 <= N <= 10000
看了题解做的,树形dp
首先求解简单一些的问题,即给定根节点,求根节点到其他节点的距离和。以示例的树为例,先求0到其他节点的距离和,应该为:1
的距离和+2
的距离和+0
的子树节点数量。1
和2
的距离和比较好理解,加上子树的节点数量是因为,从0走到子树的任意一个节点上,其距离都比从1
或者2
出发多了1,所以加上节点数量即可。
对每个节点保存一个tuple,分别是以该节点为根的树的距离和,以及以该节点为根的树的节点数量(包括节点本身),则上述的树可以写为:
0
/ \
1 2
/ | \
3 4 5
(0,1) (0,1) (0,1)
下一步写为:
0
/ \
1 2
(0,1) (3,4) = ((0+1)*3, 1*3+1)
/ | \
3 4 5
(0,1) (0,1) (0,1)
更新到根节点有:
0
(8,6) = (0+3+1+4, 1+4+1)
/ \
1 2
(0,1) (3,4)
/ | \
3 4 5
(0,1) (0,1) (0,1)
以0为根节点,最终的答案就是8
这样计算出来的是以其中1个节点作为根节点的结果,以所有节点作为根节点,则可以计算出最终答案,这种方法的时间复杂度是 o ( n 2 ) o(n^2) o(n2)。求的时候用后序遍历即可。
观察0
的子节点2
可以发现,以2
为根节点时,其实就是到2
的子树的所有节点距离都缩短了1,到除了2
的子树以外的节点都增加了1,所以ans[2] = ans[0] - num_child[2] + (N - num_child[2])
。所以求其他节点时,不用重新计算,直接在已经计算出来的基础上变换即可。由于每次变换的基础都是父节点和子节点,所以用先序遍历,用父节点的值调整子节点的值即可。
最终的时间复杂度是 o ( n ) o(n) o(n)
注意点
由于上述的思路都是基于树的,而题目给出的是边,所以要从边中构造一棵树。因为给定的边的顺序不一致,所以先按照图的方式保存相连的节点,选定一个节点用BFS/DFS调整边的顺序,使得按照这种顺序从根节点开始画一棵树。
class Solution:
def sumOfDistancesInTree(self, N: int, edges: List[List[int]]) -> List[int]:
if not edges:
return [0]
graph = {}
for from_node, to_node in edges:
if from_node not in graph:
graph[from_node] = []
graph[from_node].append(to_node)
if to_node not in graph:
graph[to_node] = []
graph[to_node].append(from_node)
# modify nodes order to form a tree
node_dict = {}
stack = [edges[0][0]]
visited_nodes = set()
while stack:
node = stack.pop()
if node in visited_nodes:
continue
visited_nodes.add(node)
if list(set(graph[node]) - visited_nodes):
node_dict[node] = list(set(graph[node]) - visited_nodes)
stack += node_dict.get(node, [])
# postorder
ans = [0] * N
child_num = [1] * N
stack = [(edges[0][0], 0)]
while stack:
node, stat = stack.pop()
if stat == 0:
stack.append((node, 1))
if node in node_dict:
for child_node in node_dict[node]:
stack.append((child_node, 0))
else:
if node not in node_dict:
continue
ans[node] = sum(ans[child_node] + child_num[child_node] for child_node in node_dict[node])
child_num[node] = 1 + sum(child_num[child_node] for child_node in node_dict[node])
# preorder
stack = [edges[0][0]]
while stack:
node = stack.pop()
if node in node_dict:
for child_node in node_dict[node]:
stack.append(child_node)
ans[child_node] = ans[node] - child_num[child_node] + (N - child_num[child_node])
return ans