BZOJ 2286消耗战

题目描述:戳这里
题解:
这题就是虚树的板子题啦。
我们要求一些最小可以阻断给定k个点到根的路径的边权和。那么考虑没有询问的情况,可以直接用树形DP。我们先用倍增求出一个点到根的路径上的最小的边权x,然后对于一个选中的节点,肯定是求它的x作为一这个点为根的子树上的答案。而对于一个未被选中的点,就有两种可能,一种是合并它子树中的点的答案,一种是当做选中的点来处理,两个取min。
但是我们考虑到有多组询问,这样做显然会超时。
这时候,如果有一棵树只保存了选中点的信息和少量未选中点的信息,考虑到 ∑ k \sum k k比较小,就能解决这个问题。虚树就是为了完成这个操作而生的。因为考虑到两个选中点之间只有未选中点得路径是没有用的,而这段路径上的最小边权我们可以用倍增求出来。那么我们就考虑怎样将这颗“大树”缩成一颗“小树”。
那么应该如何维持这颗树的形态呢?
我们可以先按照dfs序给这些给定的点排个序,按照这个顺序处理,会方便很多。
我们维护一个栈。
假设我们按照顺序做到第i个点
我们可以比较i和i-1两个点在原树中的位置关系。
1.i在i-1的子树中,那么我们直接把i推到栈中,过后处理
2.i不在i-1的子树中,但在i之前所有点的lca的子树中,那么根据我们维护的栈性质以及dfs序,深度比这个点深的都可以踢出栈了。那么我们让它们按顺序相接(top和top-1建边)然后一一踢掉,然后为了维持正确的关系,我们不能让这个点直接进栈,而是让它和最后一个被踢出的点的lca先入栈,然后再把这个点入栈。
3.i不在lca的子树中,那么栈中的点显然对后面的建树没用了,那么直接踢掉栈中的所有点,然后将当前点和前面点的lca的lca入栈,然后当前点入栈。
最后记得把剩下的点按顺序连边,然后清空栈!
总结这三种情况,无非是先踢点再入栈,在代码中不用分类讨论。
按照这种方式建树,最多多加了k个未被选点,所以复杂度还是比较科学的。
那么接下来就只要在新树上DP一下就好了啊。
注意不能要动态存储新树的边,不然每次memset就Tle了。。。

代码如下:

#include
#define ll long long
using namespace std;
const int maxn=500005,maxm=21;
int n,m,m1,Q,tot,lnk[maxn],son[2*maxn],nxt[2*maxn],w[2*maxn],id[maxn];
int dep[maxn],f[maxn][maxm],minn[maxn][maxm],b[2*maxn],stk[4*maxn];
vector<int> a[maxn];
bool vis[maxn];
void add(int x,int y,int z){
	son[++tot]=y,w[tot]=z,nxt[tot]=lnk[x],lnk[x]=tot;
}
void dfs(int x){
	id[x]=++tot;
	for (int j=1;j<=20;j++)
	f[x][j]=f[f[x][j-1]][j-1],minn[x][j]=min(minn[x][j-1],minn[f[x][j-1]][j-1]);
	for (int j=lnk[x];j;j=nxt[j])
		if (!vis[son[j]]){
			vis[son[j]]=1,f[son[j]][0]=x,minn[son[j]][0]=w[j],dep[son[j]]=dep[x]+1;
			dfs(son[j]);
		}
}
int lca(int x,int y){
	if (dep[x]<dep[y]) swap(x,y);
	for (int j=20;j>=0;j--) if (dep[f[x][j]]>=dep[y]) x=f[x][j];
	if (x==y) return x;
	for (int j=20;j>=0;j--) if (f[x][j]!=f[y][j]) x=f[x][j],y=f[y][j];
	return f[x][0];
}
int lcamin(int x,int fa){
	int ret=1<<30;
	for (int j=20;j>=0;j--)
		if (dep[f[x][j]]>=dep[fa])
			ret=min(ret,minn[x][j]),x=f[x][j];
	return ret;
}
int cmp(int x,int y){
	return id[x]<id[y];
}
void adde(int x,int y){
	a[x].push_back(y); a[y].push_back(x);
}
ll DFS(int x){
	ll ans=0;
	for (int i=0;i<a[x].size();i++)
	if (!vis[a[x][i]]){
		vis[a[x][i]]=1;
		ll tmp=DFS(a[x][i]);
		ans+=tmp;
	}
	if (x>m) {
		if (b[x]!=1) ans=min(1ll*lcamin(b[x],1),ans);
	} else ans=lcamin(b[x],1);
	return ans;
}
int main(){
	scanf("%d",&n);
	for (int i=1;i<n;i++){
		int x,y,w;
		scanf("%d%d%d",&x,&y,&w);
		add(x,y,w); add(y,x,w);
	}
	scanf("%d",&Q);
	for (int i=0;i<=20;i++) minn[1][i]=1<<30;
	dep[1]=1,tot=0,vis[1]=1; dfs(1);
	while (Q--){
		scanf("%d",&m);
		for (int i=1;i<=m;i++) scanf("%d",&b[i]);
		sort(b+1,b+1+m,cmp);
		if (m==1) {printf("%d\n",lcamin(b[1],1)); continue;}
		int top=0,m1=m; stk[++top]=1;
		for (int i=2;i<=m;i++){
			while (top>1){
				int LCA=lca(b[stk[top]],b[i]);
				if (dep[LCA]<dep[b[stk[top-1]]]) {
					adde(stk[top],stk[top-1]);
					top--;
				} else break;
			}
			int LCA=lca(b[stk[top]],b[i]);
			if (LCA!=b[stk[top]]){
				b[++m1]=LCA; adde(stk[top],m1);
				stk[top]=m1;
			}
			stk[++top]=i;
		}
		while (top>1) {adde(stk[top],stk[top-1]); top--;}
		for (int i=1;i<=m1;i++) vis[i]=0; vis[stk[top]]=1;
		printf("%lld\n",DFS(stk[top]));
		for (int i=0;i<=m1;i++) a[i].clear();
	}
	return 0;
}

你可能感兴趣的:(题解,BZOJ题解)