考场上写了一大坨树形DP,写的时候就感觉我这不是跟求树的最长链写的一毛一样
然后考后看题解,果然是k个ren所连成的子树的最长链的一半
可以利用反证法证明,如果在长度为d的最长链的中间放一个中心,如果有另外一个点到这个点的长度>(d+1)/2,那么这个点到对面的那个点的长度大于d,所以不存在这样一个点。
#include
#define maxl 100010
using namespace std;
int n,cnt,k,ans;
int a[maxl],ehead[maxl],dis[maxl];
struct ed
{
int to,nxt;
}e[maxl*2];
bool in[maxl],vis[maxl];
inline void add(int u,int v)
{
e[++cnt].to=v;e[cnt].nxt=ehead[u];ehead[u]=cnt;
}
inline void prework()
{
for(int i=1;i<=n;i++)
ehead[i]=0,in[i]=false;
int u,v;
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
for(int i=1;i<=k;i++)
{
scanf("%d",&a[i]);
in[a[i]]=true;
}
}
inline void dfs(int u)
{
vis[u]=true;
int v;
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(!vis[v])
{
dis[v]=dis[u]+1;
dfs(v);
}
}
}
inline void mainwork()
{
if(k<=1)
{
ans=0;
return;
}
for(int i=1;i<=n;i++)
vis[i]=false,dis[i]=0;
dfs(a[1]);
int u=a[1];
for(int i=1;i<=k;i++)
if(dis[a[i]]>dis[u])
u=a[i];
for(int i=1;i<=n;i++)
vis[i]=false,dis[i]=0;
dfs(u);
int v=u;
for(int i=1;i<=k;i++)
if(dis[a[i]]>dis[v])
v=a[i];
ans=(dis[v]+1)/2;
}
inline void print()
{
printf("%d\n",ans);
}
int main()
{
while(~scanf("%d%d",&n,&k))
{
prework();
mainwork();
print();
}
return 0;
}
树形DP版本,又臭又长:
#include
#define maxl 100010
using namespace std;
const int inf=2e9;
int n,k,cnt,ans;
int ehead[maxl];
int f[maxl],mxfa[maxl],mxs[maxl],secmxs[maxl];
bool mxflag[maxl],secmxflag[maxl],faflag[maxl];
struct ed
{
int to,nxt,mi;
}e[maxl<<1];
bool in[maxl];
inline void add(int u,int v)
{
e[++cnt].to=v;e[cnt].nxt=ehead[u];ehead[u]=cnt;
}
inline void prework()
{
for(int i=1;i<=n;i++)
ehead[i]=0,f[i]=0,in[i]=false;
int u,v;
cnt=0;
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
for(int i=1;i<=k;i++)
{
scanf("%d",&u);
in[u]=true;
}
}
inline void gets(int u,int fa)
{
int v,tmp;
if(in[u])
mxflag[u]=true;
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(v==fa) continue;
gets(v,u);
if(in[v])
{
in[u]=true;
if(mxflag[v])
{
tmp=mxs[v]+1;
if(tmp>mxs[u])
{
secmxs[u]=mxs[u],mxs[u]=tmp;
secmxflag[u]=mxflag[u];
mxflag[u]=true;
}
else if(tmp>secmxs[u])
secmxs[u]=tmp,secmxflag[u]=true;
}
}
}
}
inline void getf(int u,int fa)
{
int v,tmp;
if(u==1) f[1]=mxs[1];
else
{
if(mxs[u]+1==mxs[fa])
{
if(secmxflag[fa])
{
mxfa[u]=max(mxfa[u],secmxs[fa]+1);
faflag[u]=true;
}
}
else
{
if(mxflag[fa])
{
mxfa[u]=max(mxfa[u],mxs[fa]+1);
faflag[u]=true;
}
}
if(faflag[fa])
{
mxfa[u]=max(mxfa[u],mxfa[fa]+1);
faflag[u]=true;
}
}
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(v==fa) continue;
getf(v,u);
}
if(faflag[u])
f[u]=mxfa[u];
if(mxs[u])
f[u]=max(f[u],mxs[u]);
}
inline void mainwork()
{
for(int i=1;i<=n;i++){
f[i]=0,mxs[i]=0,secmxs[i]=0,mxfa[i]=0;
mxflag[i]=secmxflag[i]=faflag[i]=false;
}
gets(1,0);
mxfa[1]=0;
getf(1,0);
ans=inf;
for(int i=1;i<=n;i++)
ans=min(ans,f[i]);
}
inline void print()
{
printf("%lld\n",ans);
}
int main()
{
while(~scanf("%d%d",&n,&k))
{
prework();
mainwork();
print();
}
return 0;
}