【bzoj3572】世界树 虚树&树形dp

       很明显的虚树题。

       把关键点的虚树构建出来,然后可以两遍遍历得到离点i最近的关键点的距离和编号。那么现在考虑虚树中的一对点(x,y),x为y的某个儿子,考虑其对答案的影响。

       由于是虚树,那么显然所有y->x的路径上的点,这个点延伸出去的点中(不包含由y->x的路径)不会有关键点存在,那么离这些点最近的虚树中的点,要么是x,要么是y,而且一定是先到达y->x的路径上的某一点,然后到达x或y,最后到达关键点。那么y->x的路径上有有一个分界点z,使得y->z的路径上延伸出去的点都由y到达关键点,z->x都由x到达关键点。而x到关键点的路径一定是往下走,y则一定是先往上走,因此求出关于x个y的两个关键点的中点,就是z了。

       然后用子树大小统计路径及其延伸出去的点的个数即可,最后要加上那些没有被统计到的点。

AC代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 300005
#define inf 1000000000
using namespace std;

int n,m,tot,dfsclk,bin[25],fst[N],pnt[N<<1],nxt[N<<1],fa[N][19],d[N],pos[N];
int a[N],p[N],id[N],val[N],anc[N],len[N],ans[N],sz[N],q[N]; struct node{ int x,y; }g[N];
int read(){
	int x=0; char ch=getchar();
	while (ch<'0' || ch>'9') ch=getchar();
	while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
	return x;
}
void add(int x,int y){
	pnt[++tot]=y; nxt[tot]=fst[x]; fst[x]=tot;
}
void dfs(int x){
	pos[x]=++dfsclk; sz[x]=1; int p,i;
	for (i=1; bin[i]<=x; i++) fa[x][i]=fa[fa[x][i-1]][i-1];
	for (p=fst[x]; p; p=nxt[p]){
		int y=pnt[p];
		if (y!=fa[x][0]){
			fa[y][0]=x; d[y]=d[x]+1;
			dfs(y); sz[x]+=sz[y];
		}
	}
}
int lca(int x,int y){
	if (d[x]<d[y]) swap(x,y); int tmp=d[x]-d[y],i;
	for (i=0; bin[i]<=tmp; i++)
		if (tmp&bin[i]) x=fa[x][i];
	for (i=18; i>=0; i--)
		if (fa[x][i]!=fa[y][i]){ x=fa[x][i]; y=fa[y][i]; }
	return (x==y)?x:fa[x][0];
}
int find(int x,int dep){
	int i; for (i=18; i>=0; i--) if (d[fa[x][i]]>=dep) x=fa[x][i];
	return x;
}
bool cmp(int x,int y){ return pos[x]<pos[y]; }
bool lss(node u,node v){
	return u.x<v.x || u.x==v.x && u.y<v.y;
}
void solve(){
	m=read(); int i,cnt=m,tp=0; node t;
	for (i=1; i<=m; i++){
		a[i]=id[i]=p[i]=read(); g[a[i]].y=a[i];
		g[a[i]].x=ans[a[i]]=0;
	}
	sort(a+1,a+m+1,cmp);
	for (i=1; i<=m; i++)
		if (!tp){ q[++tp]=a[i]; anc[a[i]]=0; } else{
			int tmp=lca(a[i],q[tp]);
			for (; d[q[tp]]>d[tmp]; tp--)
				if (d[q[tp-1]]<=d[tmp]) anc[q[tp]]=tmp;
			if (q[tp]!=tmp){
				p[++cnt]=tmp; anc[tmp]=q[tp];
				q[++tp]=tmp; g[tmp].x=inf; g[tmp].y=0;
			}
			anc[a[i]]=tmp; q[++tp]=a[i];
		}
	sort(p+1,p+cnt+1,cmp);
	for (i=1; i<=cnt; i++){
		int x=p[i]; val[x]=sz[x];
		if (i>1) len[x]=d[x]-d[anc[x]];
	}
	for (i=cnt; i>1; i--){
		int x=p[i]; t=g[x]; t.x+=len[x];
		if (lss(t,g[anc[x]])) g[anc[x]]=t;
	} 
	for (i=2; i<=cnt; i++){
		int x=p[i]; t=g[anc[x]]; t.x+=len[x];
		if (lss(t,g[x])) g[x]=t;
	}
	for (i=1; i<=cnt; i++){
		int x=p[i],y=anc[x];
		if (i==1) ans[g[x].y]+=n-sz[x]; else{
			int tmp=find(x,d[y]+1),sum=sz[tmp]-sz[x];
			val[y]-=sz[tmp];
			if (g[x].y==g[y].y) ans[g[x].y]+=sum; else{
				int z=d[x]-((g[y].x+len[x]-g[x].x)>>1);
				if (!((g[y].x+g[x].x+len[x])&1) && g[x].y>g[y].y) z++;
				z=sz[find(x,z)]-sz[x];
				ans[g[x].y]+=z; ans[g[y].y]+=sum-z;
			}
		}
	}
	for (i=1; i<=cnt; i++) ans[g[p[i]].y]+=val[p[i]];
	for (i=1; i<=m; i++) printf("%d ",ans[id[i]]); puts("");
}
int main(){
	n=read(); int i;
	bin[0]=1; for (i=1; i<=19; i++) bin[i]=bin[i-1]<<1;
	for (i=1; i<n; i++){
		int x=read(),y=read();
		add(x,y); add(y,x);
	}
	d[1]=1; dfs(1);
	int cas=read(); while (cas--) solve();
	return 0;
}


by lych

2016.3.7

你可能感兴趣的:(DFS,LCA,树形DP,虚树)