leetcode-834. 树中距离之和

题目

给定一个无向、连通的树。树中有 N 个标记为 0…N-1 的节点以及 N-1 条边 。

第 i 条边连接节点 edges[i][0] 和 edges[i][1] 。

返回一个表示节点 i 与其他所有节点距离之和的列表 ans。

示例 1:

输入: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释: 
如下为给定的树的示意图:
  0
 / \
1   2
   /|\
  3 4 5

我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5) 
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。

说明: 1 <= N <= 10000

解题思路

看了题解做的,树形dp

首先求解简单一些的问题,即给定根节点,求根节点到其他节点的距离和。以示例的树为例,先求0到其他节点的距离和,应该为:1的距离和+2的距离和+0的子树节点数量。12的距离和比较好理解,加上子树的节点数量是因为,从0走到子树的任意一个节点上,其距离都比从1或者2出发多了1,所以加上节点数量即可。

对每个节点保存一个tuple,分别是以该节点为根的树的距离和,以及以该节点为根的树的节点数量(包括节点本身),则上述的树可以写为:

   0
 /    \
1       2
	  / | \
	3   4   5
(0,1) (0,1) (0,1)

下一步写为:

	   0
	 /    \
	1       2
(0,1)     (3,4) = ((0+1)*3, 1*3+1)
		  / | \
		3   4   5
	(0,1) (0,1) (0,1)

更新到根节点有:

	   0
	  (8,6) = (0+3+1+4, 1+4+1)
	 /    \
	1       2
(0,1)     (3,4)
		  / | \
		3   4   5
	(0,1) (0,1) (0,1)

以0为根节点,最终的答案就是8

这样计算出来的是以其中1个节点作为根节点的结果,以所有节点作为根节点,则可以计算出最终答案,这种方法的时间复杂度是 o ( n 2 ) o(n^2) o(n2)。求的时候用后序遍历即可。

观察0的子节点2可以发现,以2为根节点时,其实就是到2的子树的所有节点距离都缩短了1,到除了2的子树以外的节点都增加了1,所以ans[2] = ans[0] - num_child[2] + (N - num_child[2])。所以求其他节点时,不用重新计算,直接在已经计算出来的基础上变换即可。由于每次变换的基础都是父节点和子节点,所以用先序遍历,用父节点的值调整子节点的值即可。

最终的时间复杂度是 o ( n ) o(n) o(n)

注意点
由于上述的思路都是基于树的,而题目给出的是边,所以要从边中构造一棵树。因为给定的边的顺序不一致,所以先按照图的方式保存相连的节点,选定一个节点用BFS/DFS调整边的顺序,使得按照这种顺序从根节点开始画一棵树。

代码

class Solution:
    def sumOfDistancesInTree(self, N: int, edges: List[List[int]]) -> List[int]:
        if not edges:
            return [0]
        graph = {}
        for from_node, to_node in edges:
            if from_node not in graph:
                graph[from_node] = []
            graph[from_node].append(to_node)
            if to_node not in graph:
                graph[to_node] = []
            graph[to_node].append(from_node)
        # modify nodes order to form a tree
        node_dict = {}
        stack = [edges[0][0]]
        visited_nodes = set()
        while stack:
            node = stack.pop()
            if node in visited_nodes:
                continue
            visited_nodes.add(node)
            if list(set(graph[node]) - visited_nodes):
                node_dict[node] = list(set(graph[node]) - visited_nodes)
            stack += node_dict.get(node, [])
        # postorder
        ans = [0] * N
        child_num = [1] * N
        stack = [(edges[0][0], 0)]
        while stack:
            node, stat = stack.pop()
            if stat == 0:
                stack.append((node, 1))
                if node in node_dict:
                    for child_node in node_dict[node]:
                        stack.append((child_node, 0))
            else:
                if node not in node_dict:
                    continue
                ans[node] = sum(ans[child_node] + child_num[child_node] for child_node in node_dict[node])
                child_num[node] = 1 + sum(child_num[child_node] for child_node in node_dict[node])
        # preorder
        stack = [edges[0][0]]
        while stack:
            node = stack.pop()
            if node in node_dict:
                for child_node in node_dict[node]:
                    stack.append(child_node)
                    ans[child_node] = ans[node] - child_num[child_node] + (N - child_num[child_node])
        return ans

你可能感兴趣的:(OJ题目记录)