传送门
令size(i)表示i子树里有多少个关键点
令sum(i)表示i子树中所有关键点到i的距离和
令Max(i)表示i子树中所有关键点到它的最长链,_Max(i)次长链,Min(i)最短链,_Min(i)次短链
这些都非常好维护,第二问和第三问也很好计算,用最和次拼一下就行了
对于第一问的话,在dp的时候维护一下当前size和sum的乘积就行了
将所有的关键点和它们的lca建出一棵虚树,边权为两点之间的距离
然后按照上面的dp就行了
dp的时候要格外注意子树的根是否是关键点以及儿子的个数
#include
#include
#include
#include
#include
using namespace std;
#define LL long long
#define N 1000005
#define sz 20
int n,q,k,dfs_clock,top;
int tot,point[N],nxt[N*2],v[N*2],c[N*2];
int pt[N],key[N],flag[N],stack[N],h[N],in[N],out[N],f[N][sz+3],size[N];
LL sum[N],Max[N],_Max[N],Min[N],_Min[N],ans1,ans2,ans3;
const LL inf=1e18;
void add(int x,int y,int z)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;
}
void build(int x,int fa)
{
h[x]=h[fa]+1;in[x]=++dfs_clock;
for (int i=1;i1]][i-1];
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa)
{
f[v[i]][0]=x;
build(v[i],x);
}
out[x]=++dfs_clock;
}
int cmp(int a,int b)
{
return in[a]int lca(int x,int y)
{
if (h[x]int cha=h[x]-h[y];
for (int i=0;iif ((cha>>i)&1) x=f[x][i];
if (x==y) return x;
for (int i=sz-1;i>=0;--i)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
void treedp(int x)
{
size[x]=0;
sum[x]=0;
Max[x]=_Max[x]=0;
Min[x]=_Min[x]=inf;
if (key[x]==q) size[x]=1,Min[x]=0;
int cnt=0;
for (int i=point[x];i;i=nxt[i])
{
++cnt;
treedp(v[i]);
ans1+=(LL)size[v[i]]*sum[x]+(LL)size[x]*(sum[v[i]]+(LL)c[i]*(LL)size[v[i]]);
size[x]+=size[v[i]];
sum[x]+=sum[v[i]]+(LL)c[i]*(LL)size[v[i]];
if (Max[v[i]]+(LL)c[i]>Max[x])
{
_Max[x]=Max[x];
Max[x]=Max[v[i]]+(LL)c[i];
}
else _Max[x]=max(_Max[x],Max[v[i]]+(LL)c[i]);
if (Min[v[i]]+(LL)c[i]else _Min[x]=min(_Min[x],Min[v[i]]+(LL)c[i]);
}
if (key[x]==q||cnt>1)
{
ans2=min(ans2,Min[x]+_Min[x]);
ans3=max(ans3,Max[x]+_Max[x]);
}
point[x]=0;
}
int main()
{
scanf("%d",&n);
for (int i=1;iint x,y;scanf("%d%d",&x,&y);
add(x,y,1),add(y,x,1);
}
build(1,0);
memset(point,0,sizeof(point));
scanf("%d",&q);
while (q)
{
scanf("%d",&k);
for (int i=1;i<=k;++i)
{
scanf("%d",&pt[i]);
key[pt[i]]=flag[pt[i]]=q;
}
sort(pt+1,pt+k+1,cmp);pt[0]=k;
for (int i=2;i<=k;++i)
{
int r=lca(pt[i-1],pt[i]);
if (flag[r]!=q)
{
flag[r]=q;
pt[++pt[0]]=r;
}
}
if (flag[1]!=q) flag[1]=q,pt[++pt[0]]=1;
sort(pt+1,pt+pt[0]+1,cmp);
tot=0;stack[top=1]=1;
for (int i=2;i<=pt[0];++i)
{
while (in[pt[i]]stack[top]]||in[pt[i]]>out[stack[top]])
--top;
add(stack[top],pt[i],h[pt[i]]-h[stack[top]]);
stack[++top]=pt[i];
}
ans1=0;ans2=inf;ans3=0;
treedp(1);
printf("%lld %lld %lld\n",ans1,ans2,ans3);
--q;
}
}