看完题意,显然是虚树。
建出虚树后,可以容易地求出虚树上的点会被哪一个点管辖,关键在于不在虚树上的点归属于哪个点,我们分类讨论不在虚树上的点的贡献:
我们先假设虚树上的点全是关键点,注意后文的子树都是原树的子树。
实现时,我们通过一个向上和一个向下的 d p dp dp求出虚树上点的归属。
然后再对于每个点 x x x,枚举其出边 v v v,求出 m i d mid mid,计算 x , v x,v x,v的新增贡献。
并且记录一个 g x g_x gx表示 2 , 3 2,3 2,3类的答案,初始为子树大小,枚举出边 v v v时,把 x x x包含 v v v的儿子的子树结点个数去掉,最后让 x x x的贡献加上 g x g_x gx即可。
时间复杂度 O ( n l g n ) O(nlgn) O(nlgn)。
有一个实现过程中的小 t r i c k trick trick是建虚树时直接把 1 1 1结点放入虚树,会大大减少一些不必要的分类讨论。
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
//#include
//#include
//#include
#define MP(A,B) make_pair(A,B)
#define PB(A) push_back(A)
#define SIZE(A) ((int)A.size())
#define LEN(A) ((int)A.length())
#define FOR(i,a,b) for(int i=(a);i<(b);++i)
#define fi first
#define se second
using namespace std;
template<typename T>inline bool upmin(T &x,T y) {
return y<x?x=y,1:0; }
template<typename T>inline bool upmax(T &x,T y) {
return x<y?x=y,1:0; }
typedef long long ll;
typedef unsigned long long ull;
typedef long double lod;
typedef pair<int,int> PR;
typedef vector<int> VI;
const lod eps=1e-11;
const lod pi=acos(-1);
const int oo=1<<30;
const ll loo=1ll<<62;
const int mods=1e9+7;
const int MAXN=600005;
const int INF=0x3f3f3f3f;//1061109567
/*--------------------------------------------------------------------*/
inline int read()
{
int f=1,x=0; char c=getchar();
while (c<'0'||c>'9') {
if (c=='-') f=-1; c=getchar(); }
while (c>='0'&&c<='9') {
x=(x<<3)+(x<<1)+(c^48); c=getchar(); }
return x*f;
}
PR mn[MAXN];
vector<int> e[MAXN],E[MAXN];
int a[MAXN],b[MAXN],f[MAXN],g[MAXN],stk[MAXN],top=0,n,m;
int dep[MAXN],sz[MAXN],Log[MAXN],dfn[MAXN],fa[MAXN][20],head[MAXN],flag[MAXN],DFN=0,edgenum;
int getlca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for (int i=Log[dep[x]];i>=0;i--)
if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if (x==y) return x;
for (int i=Log[dep[x]];i>=0;i--)
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int jump(int x,int d)
{
for (int i=Log[dep[x]];i>=0;i--)
if (dep[fa[x][i]]>=d) x=fa[x][i];
return x;
}
void dfs(int x,int father)
{
fa[x][0]=father,sz[x]=1,dep[x]=dep[father]+1,dfn[x]=++DFN;
for (int i=1;i<=Log[dep[x]];i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for (auto v:e[x]) if (v!=father) dfs(v,x),sz[x]+=sz[v];
}
void Init()
{
dep[0]=-1,Log[1]=0;
for (int i=1;i<=n;i++) Log[i]=Log[i>>1]+1;
dfs(1,0);
}
void add(int u,int v) {
E[u].PB(v); }
void build()
{
sort(a+1,a+m+1,[&](int x,int y){
return dfn[x]<dfn[y]; });
stk[top=1]=1;
for (int i=1+(a[1]==1);i<=m;i++)
{
int lca=getlca(stk[top],a[i]);
while (top>1&&dep[stk[top-1]]>dep[lca]) add(stk[top-1],stk[top]),top--;
if (dep[stk[top]]>dep[lca]) add(lca,stk[top--]);
if (!top||stk[top]!=lca) stk[++top]=lca;
stk[++top]=a[i];
}
while (top>1) add(stk[top-1],stk[top]),top--;
}
void up(int x,int father)
{
mn[x]=(flag[x]?MP(0,x):MP(INF,x));
for (auto v:E[x])
{
if (v==father) continue;
up(v,x),upmin(mn[x],MP(mn[v].fi+dep[v]-dep[x],mn[v].se));
}
}
void down(int x,int father)
{
for (auto v:E[x])
if (v!=father) upmin(mn[v],MP(mn[x].fi+dep[v]-dep[x],mn[x].se)),down(v,x);
}
void tree_dp(int x,int father)
{
for (auto v:E[x])
if (v!=father) tree_dp(v,x);
g[x]=sz[x];
for (auto v:E[x])
{
int t=jump(v,dep[x]+1); g[x]-=sz[t];
if (mn[x].se==mn[v].se) {
f[mn[x].se]+=sz[t]-sz[v]; continue; }
int mid=v;
for (int i=Log[dep[v]];i>=0;i--)
{
int p=fa[mid][i];
if (dep[p]<=dep[x]) continue;
if (MP(dep[p]-dep[x]+mn[x].fi,mn[x].se)>MP(dep[v]-dep[p]+mn[v].fi,mn[v].se)) mid=p;
}
f[mn[x].se]+=sz[t]-sz[mid];
f[mn[v].se]+=sz[mid]-sz[v];
}
f[mn[x].se]+=g[x];
}
void clean(int x,int father)
{
for (auto v:E[x]) if (v!=father) clean(v,x);
f[x]=g[x]=0,E[x].clear();
}
void clear()
{
for (int i=1;i<=m;i++) flag[a[i]]=0;
clean(1,0),top=0;
}
signed main()
{
n=read();
for (int i=1,u,v;i<n;i++) u=read(),v=read(),e[u].PB(v),e[v].PB(u);
Init();
int Case=read();
while (Case--)
{
m=read();
for (int i=1;i<=m;i++) a[i]=b[i]=read(),flag[a[i]]=1;
build(),up(1,0),down(1,0),tree_dp(1,0);
for (int i=1;i<=m;i++) printf("%d ",f[b[i]]); puts("");
clear();
}
return 0;
}