给定一个无向、连通的树。树中有 N 个标记为 0…N-1 的节点以及 N-1 条边 。
第 i 条边连接节点 edges[i][0] 和 edges[i][1] 。
返回一个表示节点 i 与其他所有节点距离之和的列表 ans。
示例 1:
输入: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释:
如下为给定的树的示意图:
0
/
1 2
/|
3 4 5
我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。
说明: 1 <= N <= 10000
最一开始的思路是计算每个节点与其他节点之间的路径长度,如上例,有以下表格:
0 | 1 | 2 | 3 | 4 | 5 | |
---|---|---|---|---|---|---|
0 | 0 | 1 | 1 | 2 | 2 | 2 |
1 | 1 | 0 | 2 | 3 | 3 | 3 |
2 | 1 | 2 | 0 | 1 | 1 | 1 |
3 | 2 | 3 | 1 | 0 | 2 | 2 |
4 | 2 | 3 | 1 | 2 | 0 | 2 |
5 | 2 | 3 | 1 | 2 | 2 | 0 |
其中第i行第j列表示结点i到j的长度
以上矩阵沿对角线对称,我的代码是:
// leetcode834树中距离之和
vector<int> sumOfDistancesInTree(int N, vector<vector<int>>& edges) {
for (int i = 0; i < edges.size(); i++)
{
sort(edges[i].begin(), edges[i].end());
}
sort(edges.begin(), edges.end());
vector<vector<int>> Adj;
for (int i = 0; i < N; i++)
{
vector<int>vec;
Adj.push_back(vec);
}
vector<int> ans;
for (int i = 0; i<edges.size(); i++) {
int a = edges[i][0], b = edges[i][1];
Adj[a].push_back(b);
Adj[b].push_back(a);
}
vector<vector<int>> vecPoint(N, vector<int>(N, 0)); // 缓存结点之间的距离
for (int u = 0; u < N; u++) {
int sum = 0;
for (int i = 0; i < N; i++)
{
if (vecPoint[u][i] == 0 && i != u) // 对角线元素不参与运算,已经有结果的项不再参与运算
{
vector<bool> vec(N, false);//其实就是计算u和pos之间的距离
caculateDistance(Adj, vecPoint, u, i, vec);
}
sum += vecPoint[u][i];
}
ans.push_back(sum);
}
return ans;
}
int caculateDistance(vector<vector<int>>& Adj, vector<vector<int>>& vecPoint, int src, int target, vector<bool>&flag)
{
if (flag[src])
return 0;
flag[src] = true;
if (vecPoint[src][target] != 0)
{
return vecPoint[src][target];
}
//兄弟节点
vector<int> points = Adj[src];
for (int i = 0; i < points.size(); i++)
{
int value = points[i];
if (value == target) // 如果兄弟节点中有一个是目标节点,则表示找到,返回1
{
vecPoint[src][target] = vecPoint[value][target] + 1;
vecPoint[target][src] = vecPoint[value][target] + 1;
return vecPoint[src][target];
}
if (vecPoint[value][target] != 0)
{
vecPoint[src][target] = vecPoint[value][target] + 1;
vecPoint[target][src] = vecPoint[value][target] + 1;
return vecPoint[src][target];
}
}
// 如果兄弟节点没找到
for (int i = 0; i < points.size(); i++)
{
int value = caculateDistance(Adj, vecPoint, points[i], target, flag);
if (value != 0)
{
vecPoint[src][target] = value + 1;
vecPoint[target][src] = value + 1;
return vecPoint[src][target];
}
}
return 0;
}
超时了。。。
参考了一下别人的代码:
贴出来分享下:
static const auto init = []() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
return nullptr;
}();
struct Tree{
vector<Tree*> v;
Tree* par;
int val;
Tree(int val){
this->val = val;
}
};
class Solution {
public:
private:
int ans[10001];//存取答案
int node_count[10001];//存转换成树后该节点为父节点的所以节点数量(包括父节点)
public:
//访问过的节点已经成为该节点的父节点了,所以不能将该节点作为访问过节点的父节点
Tree* createTree(vector<vector<int>>& graph,int cur_node,vector<bool>& visited){
Tree* node = new Tree(cur_node);
visited[cur_node] = true;
for(int i = 0;i<graph[cur_node].size();i++){
if(!visited[graph[cur_node][i]])
node->v.push_back(createTree(graph,graph[cur_node][i],visited));
}
return node;
}
//计算每个节点子节点数量(包括该节点)
int dfs(Tree* root){
if(root==NULL){
return 0;
}
int res = 0;
for(int i = 0;i<root->v.size();i++){
res=res+dfs(root->v[i])+1;
}
if(node_count[root->val]==0) node_count[root->val]=res+1;
return res;
}
//计算root节点答案
int dfs1(Tree* root){
if(root==NULL){
return 0;
}
int res = 0;
for(int i = 0;i<root->v.size();i++){
res=res+dfs1(root->v[i])+node_count[root->v[i]->val];
}
return res;
}
//根据父节点来计算字节点答案
void solve(Tree* root,int N){
if(root == NULL){
return;
}
for(int i = 0;i<root->v.size();i++){
ans[root->v[i]->val] = ans[root->val]+N-2*node_count[root->v[i]->val];
solve(root->v[i],N);
}
}
vector<int> sumOfDistancesInTree(int N, vector<vector<int>>& edges) {
vector<int> vans(N,0);
if(edges.size()==0) return vans;
vector<vector<int>> graph(N);
memset(ans,-1,sizeof(ans));
memset(node_count,0,sizeof(node_count));
vector<bool> visited(N,false);
int cur_node = -1;
for(int i = 0;i < edges.size();i++){
graph[edges[i][0]].push_back(edges[i][1]);
graph[edges[i][1]].push_back(edges[i][0]);
if(cur_node==-1) cur_node = edges[i][0];
}
Tree* root = NULL;
while(root==NULL){
root = createTree(graph,cur_node,visited);
cur_node++;
}
dfs(root);
ans[root->val] = dfs1(root);
solve(root,N);
for(int i = 0;i<N;i++){
vans[i] = ans[i];
}
return vans;
}
};
大体的思路是计算每个结点的子节点数(包括它自己),然后先计算根节点的结果,根据根节点的结果,计算每一个子节点的结果