Description
给定一棵 \(n\) 个点的无根树,有 \(q\) 个询问,每次询问将一些点设置为关键点,要求求出每个关键点管辖了多少个节点。\(a\) 管辖 \(b\) 当且仅当 \(a\) 是距离 \(b\) 最近的关键点或是最近的关键点中编号最小的那一个。\(n\leq 3\cdot 10^5,q\leq 3\cdot 10^5,\sum len\leq 3\cdot 10^5\)。
Sol
看见这个 \(\sum len\) 那就肯定是要在虚树上面乱搞了。
首先建出虚树,然后可以通过换根\(\mathrm{dfs}\)求出虚树上每个点被哪个关键点管辖。
然后就要统计不在虚树上的那些点对答案的贡献了。
不在虚树上的点分为两部分,一是 在虚树上的某条边中,二是 在虚树上的某个点的某个子树中(这个子树是原树里的子树)
那就可以愉快的统计了。
记录 \(sze[i]\) 表示原树中子树 \(i\) 的大小,\(siz[i]\) 表示 \(\sum\limits_{(i,j)\text{not in tree}}sze[j]\),即所有没有在虚树中出现的点 \(i\) 的儿子的子树和。可以注意到,管辖 \(siz[i]\) 这些点的关键点一定管辖点 \(i\)。
然后就要求,在虚树上某条边中的未出现的点的贡献了。
对于一条虚树上的边 \((x,y)\),首先找到 \(x\) 沿着这条链的第一个孩子 \(s\) ,如果管辖两端点的关键点一样,那直接将 \(sze[s]-sze[y]\) 加进该关键点的答案内,表示这条链上所有的点以及所有点的子树都会被相同的关键点管辖。否则,就要在链上二分出来一个 \(mid\),表示 \(mid\) 以及向下的点都被管辖 \(y\) 的关键点覆盖,\(mid\) 向上的点都被管辖 \(x\) 的关键点覆盖。两个答案分别加上 \(sze[s]-sze[mid],sze[mid]-sze[y]\) 即可。
Code
先记录一下如何建虚树,就是维护一个单调栈,栈内元素深度单调递增,也就是维护了虚树的一条链。然后每次新加入点的时候判断一下,如果这条链走到了头就得往外弹栈,具体看代码,比较好理解。
void ins(int x){
if(top<=1) return stk[++top]=x,void();
int LCA=lca(stk[top],x);
if(LCA==stk[top]) return stk[++top]=x,void();
while(top>1 and dfn[stk[top-1]]>=dfn[LCA])
add(stk[top-1],stk[top]),top--;
if(LCA!=stk[top]) add(LCA,stk[top]),stk[top]=LCA;
stk[++top]=x;
}
然后是这道题的代码。
#pragma GCC optimize(2)
#include
using namespace std;
typedef double db;
typedef long long ll;
typedef pair pii;
const int N=3e5+5;
pii bel[N];
int is[N],stk[N],top,tot,f[N][20];
int lg[N],d[N],sze[N],siz[N],ans[N];
int n,m,a[N],cnt,dfn[N],head[N],b[N];
struct Edge{
int to,nxt;
}edge[N<<1];
bool cmp(int x,int y){
return dfn[x]=d[y]) x=f[x][j];
if(x==y) return x;
for(int j=lg[d[x]];~j;j--)
if(f[x][j]!=f[y][j]) x=f[x][j],y=f[y][j];
return f[x][0];
}
void ins(int x){
if(top<=1) return stk[++top]=x,void();
int LCA=lca(stk[top],x);
if(LCA==stk[top]) return stk[++top]=x,void();
while(top>1 and dfn[stk[top-1]]>=dfn[LCA])
add(stk[top-1],stk[top]),top--;
if(LCA!=stk[top]) add(LCA,stk[top]),stk[top]=LCA;
stk[++top]=x;
}
void dfs1(int now){
if(is[now]==m) bel[now]=pii(0,now);
else bel[now]=pii(1e9,0);
siz[now]=sze[now]-1;
for(int i=head[now];i;i=edge[i].nxt){
int to=edge[i].to;
dfs1(to);
bel[now]=min(bel[now],pii(bel[to].first+d[to]-d[now],bel[to].second));
}
}
void dfs2(int now,int fa){
if(now!=1)
bel[now]=min(bel[now],pii(bel[fa].first+d[now]-d[fa],bel[fa].second));
ans[bel[now].second]++;
for(int i=head[now];i;i=edge[i].nxt){
int to=edge[i].to;
dfs2(to,now);
}
}
int dis(int x,int y){
return d[x]+d[y]-2*d[lca(x,y)];
}
void dfs3(int now){
for(int i=head[now];i;i=edge[i].nxt){
int to=edge[i].to;
dfs3(to);
int son=to,mid=to;
for(int j=lg[d[son]];~j;j--)
if(d[f[son][j]]>d[now]) son=f[son][j];
siz[now]-=sze[son];
if(bel[now].second==bel[to].second){ans[bel[now].second]+=sze[son]-sze[to];continue;}
for(int j=lg[d[mid]];~j;j--){
int p=f[mid][j];
if(d[p]<=d[now]) continue;
if(pii(dis(p,bel[to].second),bel[to].second)1) add(stk[top-1],stk[top]),top--;
dfs1(1),dfs2(1,0),dfs3(1);
for(int i=1;i<=len;i++) printf("%d ",ans[b[i]]);puts("");
}
signed main(){
scanf("%d",&n);
for(int x,y,i=1;i