传送门
首先建出虚树来,边权即为原树上的距离
这题我dp的方法非常蠢
f(i)表示从i的父边出去(必须经过i的父亲)到达的关键点的最短路
fp(i)表示最短路的点
g(i)表示i到i的子树中到达的关键点的最短路
gp(i)表示最短路的点
然后这两个互相转移一下…
dp完了之后枚举虚树上的每一条边(u,v)
因为已经知道了从u出去到关键点的最短路和从v出去到关键点的最短路
然后就可以计算出这条边上的哪些点归到u的关键点去,哪些归到v的关键点去
嫌麻烦写了二分…
不过恶心的是统计答案
因为建出来的虚树只包含了关键点和其lca,在每一个点和每一条边上都有可能有若干没有关键点的子树
维护一个cnt表示虚树上每一个点所连的非关键点有多少个(包括子树)
然后根据原树的size和维护出的cnt乱搞…
调了两节课…午饭都没吃…恶心死我了
可能根本原因是我太蠢了
#include
#include
#include
#include
#include
using namespace std;
#define N 600005
#define sz 19
int n,q,m,dfs_clock,inf;
int tot,point[N],nxt[N],v[N],last[N];
int in[N],out[N],h[N],size[N],cnt[N],stack[N],top,st[N][sz+3],key[N],flag[N];
int qu[N],f[N],fp[N],g[N],gp[N],Son[N],pt[N],goal[N],ans[N];
void add(int x,int y)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
}
void build(int x,int fa)
{
in[x]=++dfs_clock;h[x]=h[fa]+1;size[x]=1;
for (int i=1;i1]][i-1];
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa)
{
st[v[i]][0]=x;
build(v[i],x);
size[x]+=size[v[i]];
}
out[x]=dfs_clock;
}
int cmp(int a,int b)
{
return in[a]<in[b];
}
int lca(int x,int y)
{
if (h[x]int k=h[x]-h[y];
for (int i=0;iif ((k>>i)&1) x=st[x][i];
if (x==y) return x;
for (int i=sz-1;i>=0;--i)
if (st[x][i]!=st[y][i]) x=st[x][i],y=st[y][i];
return st[x][0];
}
void sontofa(int x)
{
if (key[x]==q) g[x]=0,gp[x]=x;
for (int i=point[x];i;i=nxt[i])
{
sontofa(v[i]);
if (g[v[i]]+last[v[i]]else if (g[v[i]]+last[v[i]]==g[x]&&gp[v[i]]int find(int x,int k)
{
for (int i=0;iif ((k>>i)&1) x=st[x][i];
return x;
}
void calc(int len,int fa,int son)
{
int minson=g[son],minfa=f[son]-len;
if (minson&&minfa&&(len+minfaint p=find(son,len-1);
goal[fa]=goal[son]=fp[son];
ans[fp[son]]+=size[p]-size[son];
return;
}
if (minson&&minfa&&(minson+lenint p=find(son,len-1);
goal[fa]=goal[son]=gp[son];
ans[gp[son]]+=size[p]-size[son];
return;
}
int l=-1,r=len,mid,final=-1;
while (l<=r)
{
if (l==r&&l==-1)
{
final=-1;
break;
}
mid=(l+r)>>1;
if (mid+minson1;
else r=mid-1;
}
if (final==len)
{
int p=find(son,len-1);
goal[fa]=goal[son]=gp[son];
ans[gp[son]]+=size[p]-size[son];
return;
}
if (final==-1)
{
int p=find(son,len-1);
goal[fa]=goal[son]=fp[son];
ans[fp[son]]+=size[p]-size[son];
return;
}
int p=find(son,len-1),t=find(son,final);
goal[fa]=fp[son];goal[son]=gp[son];
ans[fp[son]]+=size[p]-size[t];
ans[gp[son]]+=size[t]-size[son];
return;
}
void fatoson(int x)
{
if (key[x]==q)
{
for (int i=point[x];i;i=nxt[i])
f[v[i]]=last[v[i]],fp[v[i]]=x;
}
else
{
Son[0]=0;
for (int i=point[x];i;i=nxt[i]) Son[++Son[0]]=v[i];
int Min=inf,Minp=0;
for (int i=1;i<=Son[0];++i)
{
if (f[x]+last[Son[i]]else if (f[x]+last[Son[i]]==f[Son[i]]&&fp[x]if (Min+last[Son[i]]else if (Min+last[Son[i]]==f[Son[i]]&&Minpif (g[Son[i]]+last[Son[i]]else if (g[Son[i]]+last[Son[i]]==Min&&gp[Son[i]]0;
for (int i=Son[0];i>=1;--i)
{
if (Min+last[Son[i]]else if (Min+last[Son[i]]==f[Son[i]]&&Minpif (g[Son[i]]+last[Son[i]]else if (g[Son[i]]+last[Son[i]]==Min&&gp[Son[i]]for (int i=point[x];i;i=nxt[i])
{
fatoson(v[i]);
calc(last[v[i]],x,v[i]);
}
}
void solve()
{
sontofa(1);
fatoson(1);
for (int i=1;i<=pt[0];++i)
ans[goal[pt[i]]]+=cnt[pt[i]];
for (int i=1;i<=m;++i)
printf("%d ",ans[qu[i]]);puts("");
}
int main()
{
scanf("%d",&n);
for (int i=1;iint x,y;scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
build(1,0);
scanf("%d",&q);
memset(point,0,sizeof(point));
memset(g,127,sizeof(g));inf=g[0];
memset(f,127,sizeof(f));
while (q)
{
scanf("%d",&m);
for (int i=1;i<=m;++i)
{
scanf("%d",&pt[i]);
qu[i]=pt[i];
key[pt[i]]=flag[pt[i]]=q;
goal[pt[i]]=pt[i];
}
sort(pt+1,pt+m+1,cmp);pt[0]=m;
for (int i=2;i<=m;++i)
{
int r=lca(pt[i-1],pt[i]);
if (flag[r]!=q)
flag[r]=q,pt[++pt[0]]=r;
}
if (flag[1]!=q) flag[1]=q,pt[++pt[0]]=1;
sort(pt+1,pt+pt[0]+1,cmp);
tot=0;stack[top=1]=1;cnt[1]=size[1];
for (int i=2;i<=pt[0];++i)
{
while (in[pt[i]]<in[stack[top]]||in[pt[i]]>out[stack[top]])
--top;
add(stack[top],pt[i]);
last[pt[i]]=h[pt[i]]-h[stack[top]];
int r=find(pt[i],h[pt[i]]-h[stack[top]]-1);
cnt[stack[top]]-=size[r];
stack[++top]=pt[i];cnt[pt[i]]=size[pt[i]];
}
solve();
for (int i=1;i<=pt[0];++i)
{
int x=pt[i];
point[x]=fp[x]=gp[x]=cnt[x]=ans[x]=last[x]=goal[x]=0;
f[x]=g[x]=inf;
}
--q;
}
}