先随便找个点dfs一次找到距离最远的点,再从那个点再同样dfs一次,那个点到其他点最长的距离即为树的直径
不过这种方法不适用于有负权边
#include
#include
#include
#include
#include
#define REP(i, a, b) for(register int i = (a); i < (b); i++)
#define _for(i, a, b) for(register int i = (a); i <= (b); i++)
using namespace std;
const int MAXN = 1e5;
struct node { int v, w; };
vector g[MAXN];
int d[MAXN], n, m;
void dfs(int u, int fa)
{
REP(i, 0, g[u].size())
{
int v = g[u][i].v;
if(v == fa) continue;
d[v] = d[u] + g[u][i].w;
dfs(v, u);
}
}
int main()
{
while(~scanf("%d%d", &n, &m))
{
_for(i, 1, n) g[i].clear();
while(m--)
{
int u, v, w; char s[5];
scanf("%d%d%d%s", &u, &v, &w, s);
g[u].push_back(node{v, w});
g[v].push_back(node{u, w});
}
memset(d, 0, sizeof(d));
dfs(1, -1);
int ans = 0, p;
_for(i, 1, n)
if(ans < d[i])
{
ans = d[i];
p = i;
}
memset(d, 0, sizeof(d));
dfs(p, -1);
ans = 0;
_for(i, 1, n)
if(ans < d[i])
ans = d[i];
printf("%d\n", ans);
}
return 0;
}
还可以用树形dp
树的直径是由其中一个端点到其他端点的最远距离和次远距离组成的
可以用这个性质来dfs
树形dp其实更好写
#include
#include
#include
#include
#include
#define REP(i, a, b) for(register int i = (a); i < (b); i++)
#define _for(i, a, b) for(register int i = (a); i <= (b); i++)
using namespace std;
const int MAXN = 1e5;
struct node { int v, w; };
vector g[MAXN];
int n, m, ans;
int dp[MAXN][2];
void dfs(int u, int fa)
{
REP(i, 0, g[u].size())
{
int v = g[u][i].v, w = g[u][i].w;
if(v == fa) continue;
dfs(v, u);
if(dp[u][1] < dp[v][1] + w)
{
dp[u][0] = dp[u][1];
dp[u][1] = dp[v][1] + w;
}
else if(dp[u][0] < dp[v][1] + w)
dp[u][0] = dp[v][1] + w;
}
ans = max(ans, dp[u][0] + dp[u][1]);
}
int main()
{
while(~scanf("%d%d", &n, &m))
{
memset(dp, 0, sizeof(dp));
_for(i, 1, n) g[i].clear();
while(m--)
{
int u, v, w; char s[5];
scanf("%d%d%d%s", &u, &v, &w, s);
g[u].push_back(node{v, w});
g[v].push_back(node{u, w});
}
ans = 0;
dfs(1, -1);
printf("%d\n", ans);
}
return 0;
}
然后发现其实dp数组可以省去,用两个变量存一下就可以了。
#include
#include
#include
#include
#include
#define REP(i, a, b) for(register int i = (a); i < (b); i++)
#define _for(i, a, b) for(register int i = (a); i <= (b); i++)
using namespace std;
const int MAXN = 1e5;
struct node { int v, w; };
vector g[MAXN];
int n, m, ans;
int dfs(int u, int fa)
{
int max1 = 0, max2 = 0;
REP(i, 0, g[u].size())
{
int v = g[u][i].v, w = g[u][i].w;
if(v == fa) continue;
int now = dfs(v, u) + w;
if(max1 < now) max2 = max1, max1 = now;
else if(max2 < now) max2 = now;
}
ans = max(ans, max1 + max2);
return max1;
}
int main()
{
while(~scanf("%d%d", &n, &m))
{
_for(i, 1, n) g[i].clear();
while(m--)
{
int u, v, w; char s[5];
scanf("%d%d%d%s", &u, &v, &w, s);
g[u].push_back(node{v, w});
g[v].push_back(node{u, w});
}
ans = 0;
dfs(1, -1);
printf("%d\n", ans);
}
return 0;
}
做树上问题一定要多画图。
这道题其实画画图就出来了
首先连边可以成环,可以让环上的路径只走一次。
那么走最少肯定就连直径啦。
关键是第二条怎么连
我们可以试着连一下
可以发现两个环重合的部分走了两次。
那这个怎么处理。
如果还是一样做的话重叠的部就从走1次变成1 - 1 = 0次了
那么因为应该走两次
所以我们可以把第一次环路径上的边权设为-1
那么重叠的部分就是1-(-1) = 2
符合我们推出的结论。
#include
#define REP(i, a, b) for(register int i = (a); i < (b); i++)
#define _for(i, a, b) for(register int i = (a); i <= (b); i++)
using namespace std;
const int MAXN = 1e5 + 10;
struct Edge { int to, w, next; };
Edge e[MAXN << 1];
int head[MAXN], tot, n, k, ans, s;
int son1[MAXN], son2[MAXN], res;
void AddEdge(int from, int to)
{
e[tot] = Edge{to, 1, head[from]};
head[from] = tot++;
}
int dfs(int u, int fa)
{
int max1 = 0, max2 = 0;
for(int i = head[u]; ~i; i = e[i].next)
{
int v = e[i].to, w = e[i].w;
if(v == fa) continue;
int now = dfs(v, u) + w;
if(max1 < now)
{
max2 = max1, max1 = now;
son2[u] = son1[u]; son1[u] = i;
}
else if(max2 < now) max2 = now, son2[u] = i;
}
if(ans < max1 + max2) ans = max1 + max2, s = u;
return max1;
}
void work()
{
memset(son1, -1, sizeof(son1));
memset(son2, -1, sizeof(son2));
ans = 0;
dfs(1, -1);
res -= ans - 1;
}
int main()
{
memset(head, -1, sizeof(head));
tot = 0; s = 1;
scanf("%d%d", &n, &k);
REP(i, 1, n)
{
int u, v;
scanf("%d%d", &u, &v);
AddEdge(u, v); AddEdge(v, u);
}
res = (n - 1) * 2;
work();
if(k > 1)
{
for(int i = son1[s]; ~i; i = son1[e[i].to]) e[i].w = e[i^1].w = -1;
for(int i = son2[s]; ~i; i = son1[e[i].to]) e[i].w = e[i^1].w = -1; //注意是i = son1[e[i].to],不是son2
work();
}
printf("%d\n", res);
return 0;
}