题解:虚树部分参见上一篇博客
然后DP部分随便乱搞就过了。
代码:
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #define N 1001000 #define LOGN 22 #define inf 0x3f3f3f3f #define INF 0x3f3f3f3f3f3f3f3fLL using namespace std; struct KSD { int v,next; }e[N<<1]; int head[N],cnt; void add(int u,int v) { cnt++; e[cnt].v=v; e[cnt].next=head[u]; head[u]=cnt; } int fa[N][LOGN]; int l[N][LOGN]; int deep[N],pos[N]; void dfs(int x,int p) { int i,v; deep[x]=deep[p]+1; pos[x]=++cnt; for(i=head[x];i;i=e[i].next) { v=e[i].v; if(v==p)continue; fa[v][0]=x; l[v][0]=1; dfs(v,x); } } void array(int n,int logn=LOGN-1) { int i,j; for(j=1;j<=logn;j++) for(i=1;i<=n;i++) { fa[i][j]=fa[fa[i][j-1]][j-1]; l[i][j]=l[i][j-1]+l[fa[i][j-1]][j-1]; } } inline int getlca(int x,int y,int logn=LOGN-1) { int i; if(deep[x]<deep[y])swap(x,y); for(i=logn;i>=0;i--) if(deep[fa[x][i]]>=deep[y]) x=fa[x][i]; if(x==y)return x; for(i=logn;i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } inline int getlen(int x,int y,int logn=LOGN-1) { int ans=0; if(deep[x]<deep[y])swap(x,y); for(int i=logn;i>=0;i--) if(deep[fa[x][i]]>=deep[y]) { ans+=l[x][i]; x=fa[x][i]; } return ans; } struct Graph { struct AKL { int v,next,len; }e[N]; int head[N],cnt; int vis[N],yyc[N],T; int size[N],sizesum; long long lmin[N],lmin2[N],lmax[N],lmax2[N]; long long anssum,ansmin,ansmax; void add(int u,int v,int len) { e[++cnt].v=v; e[cnt].len=len; if(vis[u]!=T)vis[u]=T,e[cnt].next=0; else e[cnt].next=head[u]; head[u]=cnt; } void set(int x){yyc[x]=T,sizesum++;} void dfs(int x) { int i,v,temp; lmax[x]=lmax2[x]=yyc[x]==T?0:-inf; lmin[x]=lmin2[x]=inf; size[x]=(yyc[x]==T); if(vis[x]==T) { for(i=head[x];i;i=e[i].next) { v=e[i].v; dfs(v); size[x]+=size[v]; temp=e[i].len+(yyc[v]==T?0:lmin[v]); if(temp<lmin[x]) { lmin2[x]=lmin[x]; lmin[x]=temp; } else if(temp<lmin2[x])lmin2[x]=temp; temp=lmax[v]+e[i].len; if(temp>lmax[x]) { lmax2[x]=lmax[x]; lmax[x]=temp; } else if(temp>lmax2[x])lmax2[x]=temp; if(yyc[x]==T)ansmin=min(ansmin,lmin[x]); else ansmin=min(ansmin,lmin[x]+lmin2[x]); ansmax=max(ansmax,lmax[x]+lmax2[x]); anssum+=(long long)e[i].len*size[v]*(sizesum-size[v]); } } } void dp() { anssum=0,ansmin=inf,ansmax=0; dfs(1); printf("%lld %lld %lld\n",anssum,ansmin,ansmax); } }G; int n,q; struct Lux { int x,pos; Lux(int _x=0,int _pos=0):x(_x),pos(_pos){} bool operator < (const Lux &a)const{return pos<a.pos;} }lux[N]; int stk[N],top; int main() { // freopen("test.in","r",stdin); int i,j,k; int a,b,c; scanf("%d",&n); for(i=1;i<n;i++) { scanf("%d%d",&a,&b); add(a,b),add(b,a); } cnt=0; dfs(1,0); array(n); for(scanf("%d",&q);q--;) { G.T++,G.cnt=0,G.sizesum=0; for(scanf("%d",&n),i=1;i<=n;i++) scanf("%d",&a),lux[i]=Lux(a,pos[a]),G.set(a); sort(lux+1,lux+n+1); stk[top=1]=1; for(i=1;i<=n;i++) { int lca=getlca(stk[top],lux[i].x); while(deep[lca]<deep[stk[top]]) { if(deep[stk[top-1]]<=deep[lca]) { int last=stk[top--]; if(stk[top]!=lca)stk[++top]=lca; G.add(lca,last,getlen(lca,last)); break; } G.add(stk[top-1],stk[top],getlen(stk[top-1],stk[top])),top--; } if(stk[top]!=lux[i].x)stk[++top]=lux[i].x; } while(top>1)G.add(stk[top-1],stk[top],getlen(stk[top-1],stk[top])),top--; G.dp(); } return 0; }