你是否遇到过下面这种问题:
SDOI2011 消耗战
在一场战争中,战场由 n n n 个岛屿和 n − 1 n-1 n−1 个桥梁组成,保证每两个岛屿间有且仅有一条路径可达。现在,我军已经侦查到敌军的总部在编号为1的岛屿,而且他们已经没有足够多的能源维系战斗,我军胜利在望。已知在其他 k k k 个岛屿上有丰富能源,为了防止敌军获取能源,我军的任务是炸毁一些桥梁,使得敌军不能到达任何能源丰富的岛屿。由于不同桥梁的材质和结构不同,所以炸毁不同的桥梁有不同的代价,我军希望在满足目标的同时使得总代价最小。
侦查部门还发现,敌军有一台神秘机器。即使我军切断所有能源之后,他们也可以用那台机器。机器产生的效果不仅仅会修复所有我军炸毁的桥梁,而且会重新随机资源分布(但可以保证的是,资源不会分布到1号岛屿上)。不过侦查部门还发现了这台机器只能够使用m次,所以我们只需要把每次任务完成即可。
对于 100 % 100\% 100% 的数据, 2 < = n < = 250000 , m > = 1 , ∑ k i < = 500000 , 1 < = k i < = n − 1 2<=n<=250000,m>=1,\sum k_i<=500000,1<=k_i<=n-1 2<=n<=250000,m>=1,∑ki<=500000,1<=ki<=n−1
很明显,这是一道树形 DP 的题,但是如果我们对于每次任务都重新做一次树形DP,就会超时。
这时候我们想,我们每次重新树形 DP,有很多点是没有任何用的, 我们可以只将有用的点提出来,重新建棵树,在做树形 DP。
这就是虚树的基本思想。
我们先来看看原树的 DP 如何实现:
f i = { m i n ( ∑ j ∈ s o n i f j , m i n d i ) i 不是资源点 m i n d i i 是资源点 f_i = \left \{ \begin{array}{ll} min(\sum_{j \in son_i} f_ j, mind_i) &i不是资源点\\ mind_i &i是资源点\\ \end{array} \right. fi={min(∑j∈sonifj,mindi)mindii不是资源点i是资源点
其中 f i f_i fi 表示第 i i i 个子树中所有资源点到根节点都断开的最小代价, m i n d i mind_i mindi 表示点 i i i 到根节点的所有的边的最小值。
答案就是 f 1 f_1 f1。
因为其中有很多节点删去,然后将它的儿子连向它的父亲之后并不会对答案产生任何影响。(在 m i n d i mind_i mindi 已经求出来的情况下)
所以我们可以将所有没用的节点删去,将有用的节点保留下来,连成一棵树,再在树上做 DP,此时的答案跟在原树的答案是一样的。
接下来,我们看看哪些节点是对我们有用的。
首先就是每一个资源点。(其实并不是所有的资源点,如果一个资源点的祖先有资源点,则该节点可不用考虑,因为如果删的边在该节点到祖先的资源点之间,则该资源点的祖先还要再删,所以可以不用考虑当前的节点)
然后就是资源点之间的 LCA。
为什么是资源点之间的 LCA 呢?
我们可以看回 DP 式。
f i = { m i n ( ∑ j ∈ s o n i f j , m i n d i ) i 不是资源点 m i n d i i 是资源点 f_i = \left \{ \begin{array}{ll} min(\sum_{j \in son_i} f_ j, mind_i) &i不是资源点\\ mind_i &i是资源点\\ \end{array} \right. fi={min(∑j∈sonifj,mindi)mindii不是资源点i是资源点
因为我们保证该子树没有可以到根节点的路径的话不是删除当前子树的根节点到父亲的边,而是当前子树的根节点到原树的根节点的所有边的最小值。
所以保存 LCA 一定不会比保存 LCA 的祖先更劣。
于是我们就可以将所有有用的点取出来,按照原本的祖先顺序连成一棵树,再在树上做 DP 就可以了。
我们看看如何将所有有用的点拉出来重新建一棵树。
首先我们求出原树的 dfs 序。
然后我们用一个栈来存储,现将根节点入栈。
我们将所有的资源点按 dfs 序排序,从小到大入栈。
对于 i i i 节点入栈,计 i i i 节点跟栈顶的节点的 LCA 为 l c a lca lca,我们分一下几种情况:
对于所有节点做完之后,如果栈里面还有节点,则一一弹出,并将相邻两个节点连边。
单看是很难理解的, 最好画图模拟一下,可以加深理解。
可以参考下下面的代码。
void build(){
tot=0;
stk[++tot]=1;
for(int i=1;i<=m;i++){
if(tot==1)
stk[++tot]=a[i];
int lca=LCA(stk[tot],a[i]);
if(lca==stk[tot])
continue;
while(dfn[lca]<=dfn[stk[tot - 1]] && tot>1)
addedge(stk[tot - 1],stk[tot]),tot--;
if(lca != stk[tot])
addedge(lca,stk[tot]),stk[tot]=lca;
stk[++tot]=a[i];
}
while(tot>1)
addedge(stk[tot - 1],stk[tot]),tot--;
}
到这里,整个过程就完成了。
下面是开头给出的题目的代码:
#include
using namespace std;
typedef long long LL;
namespace solve{
int Ecnt,last[500005],m,dfn[500005],fa[500005][35],d[500005],a[500005];
LL mind[500005];
int stk[500005],tot;
struct Edge { int to,next;} E[500005];
bool cmp(int a,int b){ return dfn[a]<dfn[b];}
void addedge(int u,int v){ Ecnt++,E[Ecnt].to=v,E[Ecnt].next=last[u],last[u]=Ecnt;}
int LCA(int x,int y){
if(d[x]<d[y])
swap(x,y);
for(int i=20;~i;i--)
if(d[fa[x][i]]>=d[y])
x=fa[x][i];
if(x==y)
return x;
for(int i=20;~i;i--)
if(fa[x][i] !=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void build(){
tot=0;
stk[++tot]=1;
for(int i=1;i<=m;i++){
if(tot==1)
stk[++tot]=a[i];
int lca=LCA(stk[tot],a[i]);
if(lca==stk[tot])
continue;
while(dfn[lca]<=dfn[stk[tot - 1]] && tot>1)
addedge(stk[tot - 1],stk[tot]),tot--;
if(lca != stk[tot])
addedge(lca,stk[tot]),stk[tot]=lca;
stk[++tot]=a[i];
}
while(tot>1)
addedge(stk[tot - 1],stk[tot]),tot--;
}
LL dp(int x){
if(last[x]==0)
return mind[x];
LL sum=0;
for(int xy=last[x];xy;xy=E[xy].next)
sum+=dp(E[xy].to);
last[x]=0;
return min(sum,mind[x]);
}
void main(){
scanf("%d",&m);
for(int i=1;i<=m;i++)
scanf("%d",&a[i]);
sort(a+1,a+1+m,cmp);
build();
printf("%lld\n",dp(1));
Ecnt=0;
}
}
int n,Q,Ecnt,cnt,last[500005];
struct Edge { int to,next;LL val;} E[500005];
void addedge(int u,int v,LL w){ Ecnt++,E[Ecnt].to=v,E[Ecnt].next=last[u],last[u]=Ecnt,E[Ecnt].val=w;}
void get_dfn(int x){
solve::dfn[x]=++cnt;
for(int xy=last[x];xy;xy=E[xy].next)
if(E[xy].to !=solve::fa[x][0]){
solve::mind[E[xy].to]=min(solve::mind[x],E[xy].val),solve::d[E[xy].to]=solve::d[x]+1;
solve::fa[E[xy].to][0]=x;
get_dfn(E[xy].to);
}
}
void get_fa(){
for(int i=1;i<=20;i++)
for(int j=1;j<=n;j++)
solve::fa[j][i]=solve::fa[solve::fa[j][i - 1]][i - 1];
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++)
solve::mind[i]=1e16;
LL w;
for(int i=1,u,v;i<n;i++)
scanf("%d%d%lld",&u,&v,&w),addedge(u,v,w),addedge(v,u,w);
get_dfn(1);
get_fa();
scanf("%d",&Q);
while(Q--)
solve::main();
return 0;
}