给定一个大小为 n n n的无根树,给定一个大小为 K K K的关键点集,求从每个点出发经过这些点的代价和(可以不返回)
数据范围: n ≤ 5 × 1 0 5 n\leq 5\times 10^5 n≤5×105
快乐树形 d p dp dp
一个非常关键的结论(建议看懂这句话之后再看下面的题解)
从第 i i i个点出发的答案即为 i i i与所有关键点组成虚树的长度和的两倍-该虚树直径【后面那个减得原因是因为不能返回】
原因是每条边显然经过两次,除了直径
由于树无根,所以我们选定一个关键点作为根,记为 r t rt rt
设 s i z [ x ] siz[x] siz[x]表示 x x x的子树中关键点的个数。作用:判断子树是否存在关键点
l e n [ i ] len[i] len[i]虚树所有点到 i i i的距离
f [ i ] f[i] f[i]在两遍 d f s dfs dfs中有不同的意义,第一遍是指以 r t rt rt为根的最长链,第二遍直接表示虚树内直径
g [ i ] g[i] g[i]直接表示以 r t rt rt为根的树上次长链,利用第一遍结束后的 f , g f,g f,g可以求直径,当然你也可以多跑两遍 d f s dfs dfs来求,做法是相似的
M a x _ s o n [ x ] Max\_son[x] Max_son[x]表示第一遍 d f s dfs dfs的 f [ x ] f[x] f[x]指向的是哪个儿子,同样应用于求虚树直径
最终答案即为 A n s i = 2 l e n i − f i Ans_i=2len_i-f_i Ansi=2leni−fi
时间复杂度: O ( n ) O(n) O(n)
#include
#include
#define LL long long
#define N 500010
using namespace std;int n,k,l[N],tot,Max_son[N],rt;
bool v[N];
struct node{
int next,to;LL w;}e[N<<1];
inline void add(int u,int v,LL w){
e[++tot]=(node){
l[u],v,w};l[u]=tot;return;}
LL ans,z,f[N],g[N],siz[N],len[N];
inline LL read()
{
char c;LL d=1,f=0;
while(c=getchar(),!isdigit(c)) if(c=='-') d=-1;f=(f<<3)+(f<<1)+c-48;
while(c=getchar(),isdigit(c)) f=(f<<3)+(f<<1)+c-48;
return d*f;
}
inline void dfs1(int x,int fa=-1)
{
siz[x]=v[x];
for(register int i=l[x];i;i=e[i].next)
{
int y=e[i].to;LL w=e[i].w;
if(y==fa) continue;
dfs1(y,x);
siz[x]+=siz[y];len[x]+=len[y];
if(siz[y])
{
len[x]+=w;
if(f[y]+w>=f[x]) g[x]=f[x],f[x]=f[y]+w,Max_son[x]=y;
else if(f[y]+w>=g[x]) g[x]=f[y]+w;
}
}
return;
}
inline void dfs2(int x,int fa=-1)
{
for(register int i=l[x];i;i=e[i].next)
{
int y=e[i].to;LL w=e[i].w,Diam;
if(y==fa) continue;
len[y]=len[x];
if(siz[y]==0) len[y]+=w;
if(siz[y]==k) len[y]-=w;
if(Max_son[x]==y) Diam=g[x]+w;//如果这个儿子已经是最长链,则当前Diam=次长链的长度+w
else Diam=f[x]+w;//否则直接用最长链
if(Diam>=f[y]) g[y]=f[y],f[y]=Diam,Max_son[y]=0;
else if(Diam>=g[y]) g[y]=Diam;
dfs2(y,x);
}
return;
}
signed main()
{
n=read();k=read();
for(register int i=1,x,y;i<n;i++) x=read(),y=read(),z=read(),add(x,y,z),add(y,x,z);
for(register int i=1,x;i<=k;i++) v[rt=read()]=true;
dfs1(rt);dfs2(rt);
for(register int i=1;i<=n;i++) printf("%lld\n",2*len[i]-f[i]);
}