题目大意:给定一棵无根树,边权都是1,去掉一条边并加上一条新边,输出所有可能的新树的直径中的最小值和最大值
一看就知道肯定是treeDP了
我们可以考虑对于每一条边,假设把它删掉,会得到两棵小树
此时新树的最大直径肯定是把两个小树的直径连起来,使得L=L1+L2+1
而最小直径肯定是连接两个直径的中点,使得L≈max(L1,L2)
所以我们需要得知对于每条边两边的小树的直径分别是多少,这个可以用两次treeDP求出
第一次就是以1为根求出每个子树的直径,这个很好求
那第二次怎么计算整个的树中刨去一个子树后剩下的那棵树的直径呢?
这棵树的直径一定是由分割点的父节点转移而来,我们只需求出在这次扩展之后新产生的链,即过这个父节点的最长链
所以在第一次treeDP时,还需要维护对于每个子树以根为端点的最长链,和次长链
那如果现在要求的就是刨去包含最长链或次长链的那棵子树的直径呢?
所以还要维护第三长链.....
这样就知道要删哪条边了
然后第二问就随便搞搞(可以O(N)暴力),求出要在哪两个点之间加边就好了
#include
#include
#include
#define N 1000010
using namespace std;
int to[N],nxt[N],pre[N],cnt;
void ae(int ff,int tt)
{
cnt++;
to[cnt]=tt;
nxt[cnt]=pre[ff];
pre[ff]=cnt;
}
int fa[N],d1[N],fir[N],sec[N],thr[N];
int fd[N],sd[N];
void dfs(int x)
{
int i,j;
for(i=pre[x];i;i=nxt[i])
{
j=to[i];
if(j==fa[x]) continue;
fa[j]=x;
dfs(j);
thr[x]=max(fir[j]+1,thr[x]);
if(thr[x]>sec[x]) swap(thr[x],sec[x]);
if(sec[x]>fir[x]) swap(sec[x],fir[x]);
d1[x]=max(d1[x],d1[j]);
sd[x]=max(d1[j],sd[x]);
if(sd[x]>fd[x]) swap(sd[x],fd[x]);
}
d1[x]=max(d1[x],fir[x]+sec[x]);
}
int minn=707185547,x4,y4;
int maxn,x5,y5;
int cal(int x,int y)
{
if(x>y) swap(x,y);
int yy=y;
x=x/2+x%2;y=y/2+y%2;
return max(x+y+1,yy);
}
void solve(int x,int d2,int lg)
{
// cout<maxn) maxn=d1[x]+d2+1,x5=fa[x],y5=x;
}
int i,j;
for(i=pre[x];i;i=nxt[i])
{
j=to[i];
if(j==fa[x]) continue;
t=fir[x];tmp=fd[x];
if(tmp==d1[j]) tmp=sd[x];
if(t==fir[j]+1)
{
t=sec[x];
tmp=max(tmp,sec[x]+max(thr[x],lg));
}
else if(sec[x]==fir[j]+1) tmp=max(tmp,fir[x]+max(thr[x],lg));
else tmp=max(tmp,fir[x]+max(sec[x],lg));
t=max(lg,t);tmp=max(tmp,d2);
solve(j,tmp,t+1);
}
}
int t[5],T,ma;
bool del[N];
void dfs1(int x,int ff,int dd)
{
int i,j;
if(dd>ma) ma=dd,t[T]=x;
for(i=pre[x];i;i=nxt[i])
{
j=to[i];
if(j==ff||del[i]) continue;
dfs1(j,x,dd+1);
}
}
int ans[3],TT;
bool getans(int x,int ff,int dd)
{
if(x==t[T])
{
if(ma==0) ans[TT]=x;
return true;
}
int i,j;
for(i=pre[x];i;i=nxt[i])
{
j=to[i];
if(j==ff||del[i]) continue;
if(getans(j,x,dd+1))
{
if(dd==ma/2) ans[TT]=x;
return true;
}
}
return false;
}
void findmin(int x,int y)
{
int i;
for(i=pre[x];i;i=nxt[i])
if(to[i]==y) del[i]=true;
for(i=pre[y];i;i=nxt[i])
if(to[i]==x) del[i]=true;
ma=-1;T=1;
dfs1(x,y,0);
ma=-1;T=2;
dfs1(t[1],0,0);
TT=1;
getans(t[1],0,0);
ma=-1;T=3;
dfs1(y,x,0);
ma=-1;T=4;
dfs1(t[3],0,0);
TT=2;
getans(t[3],0,0);
printf("%d %d %d %d %d\n",minn,x,y,ans[1],ans[2]);
for(i=pre[x];i;i=nxt[i])
if(to[i]==y) del[i]=false;
for(i=pre[y];i;i=nxt[i])
if(to[i]==x) del[i]=false;
}
void findmax(int x,int y)
{
int i;
for(i=pre[x];i;i=nxt[i])
if(to[i]==y) del[i]=true;
for(i=pre[y];i;i=nxt[i])
if(to[i]==x) del[i]=true;
ma=-1;T=1;
dfs1(x,y,0);
ma=-1;T=2;
dfs1(t[1],0,0);
ma=-1;T=3;
dfs1(y,x,0);
ma=-1;T=4;
dfs1(t[3],0,0);
printf("%d %d %d %d %d",maxn,x,y,t[1],t[3]);
}
int main()
{
int n,m;
scanf("%d",&n);
int i,j,x,y;
for(i=1;i