虚树是指在原树上选择若干点组成的树,它在原树的基础上做了一些简化,但是保留必要的信息,从而使得计算更加高效。
虚树主要用于树形DP中,能够减少顶点数,降低时间复杂度。
题目传送门
给出一棵树,n个顶点。每条边有边权。有m次询问,每次询问给出k个询问点,问使得这
k个点均不与1号点(根节点)相连的最小代价。
这道题可以用树形dp来做:
显然:我们设 d p [ i ] dp[i] dp[i]表示以 i i i为根的子树内满足题意的最小代价, m i n d [ i ] mind[i] mind[i]表示点 i i i到根节点的路径中最小边权是多少
1、点 i i i是询问点: d p i = m i n d i dp_i = mind_i dpi=mindi
2、点 i i i不是询问点: d p i = m i n ( m i n d i , ∑ d p j ( j 是 i 的儿子子树 j 里面有询问点 ) ) dp_i = min (mind_i , \sum dp_j(j是i的儿子子树j里面有询问点)) dpi=min(mindi,∑dpj(j是i的儿子子树j里面有询问点))
分析一下时间复杂度,发现是 O(n*m) 的,好像过不了
此时我们就要想一想 优化
因为我们每次其实只会考虑那些询问点,而询问点的数量满足 k < = 500005 k<=500005 k<=500005 ,所以我们花费了大多数时间在跑那些没有意义的节点。
所以我们需要 重建一棵树,使得树上所有节点都是有意义的,这就是虚树 即只包含所有的询问点和他们的 l c a lca lca
首先要给原树做一次 d f s dfs dfs,打上时间戳 d f n dfn dfn
然后搞一个栈,里面存的是一条链。
先给1号节点入栈, s t k [ + + t o p ] = 1 stk[++top] = 1 stk[++top]=1
然后考虑当前待加入的节点 p ( p 是询问点 ) p(p是询问点) p(p是询问点)
如果 s t k [ t o p ] = = l c a ( p , s t k [ t o p ] ) stk[top] == lca (p , stk[top]) stk[top]==lca(p,stk[top]) , 那么一定满足条件,直接进栈,$ stk[++top] = p$
否则, p 和 l c a ( p , s t k [ t o p ] ) p和lca(p , stk[top]) p和lca(p,stk[top])一定是不同子树里面的,则: w h i l e ( d f n [ l c a ( s t k [ t o p ] , p ) ] < = d f n [ s t k [ t o p − 1 ] ] ) 弹出 t o p − − while(dfn[lca(stk[top] , p)] <= dfn[stk[top - 1]]) \ 弹出 top -- while(dfn[lca(stk[top],p)]<=dfn[stk[top−1]]) 弹出top−−
如果 d f n [ l c a ( s t k [ t o p ] , p ) ] ! = d f n [ s t k [ t o p ] ] dfn[lca(stk[top] , p)] != dfn[stk[top]] dfn[lca(stk[top],p)]!=dfn[stk[top]]那么要将 l c a 和 p lca和p lca和p分别入栈
最后要把栈清空
记住:每次退出时都要在虚树上连边( s t k [ t o p ] , s t k [ t o p − 1 ] stk[top] , stk[top - 1] stk[top],stk[top−1])
#include
#define LL long long
using namespace std;
const int N = 250005 , M = 5e5 + 5;
int p1[N] , p , dep[N] , top[N] , sz[N] , son[N] , hd[N] , cnt , fa[N] , tp , stk[N] , fg[N] , m , h[M] , k , u , v , n , h1;
LL dp[N] , mind[N] , wi;
bool comp (int x , int y) {
return p1[x] < p1[y];
}
struct E {
int to , nt ;
long long w;
} e[N * 2];
void add (int x , int y , long long z) {
e[++cnt].to = y;
e[cnt].nt = hd[x];
e[cnt].w = z;
hd[x] = cnt;
}
inline void dfs1 (int x) {
p1[x] = ++p;
int y , maxs = 0;
sz[x] = 1;
for (int i = hd[x] ; i ; i = e[i].nt) {
y = e[i].to;
if (y == fa[x])
continue;
dep[y] = dep[x] + 1;
fa[y] = x;
mind[y] = min (mind[x] , e[i].w);
dfs1 (y);
sz[x] += sz[y];
if (sz[y] > maxs) {
maxs = sz[y];
son[x] = y;
}
}
}
inline void dfs2 (int x) {
int y;
if (son[x]) {
top[son[x]] = top[x];
dfs2 (son[x]);
}
else
return;
for (int i = hd[x] ; i ; i = e[i].nt) {
y = e[i].to;
if (y == fa[x] || y == son[x])
continue;
top[y] = y;
dfs2 (y);
}
}
int lca (int x , int y) {
while (top[x] != top[y]) {
if (dep[fa[top[x]]] > dep[fa[top[y]]])
swap(x , y);
y = fa[top[y]];
}
return dep[x] < dep[y] ? x : y;
}
inline void Insert (int x) {
int y = stk[tp];
int Lca = lca (x , y);
while (dep[y] > dep[Lca] && tp) {
if (dep[stk[tp - 1]] < dep[Lca]) {
add (Lca , y , 0);
}
else
add (stk[tp - 1] , stk[tp] , 0);
tp--;
y = stk[tp];
}
if (y == Lca) {
stk[++tp] = x;
}
else {
stk[++tp] = Lca;
stk[++tp] = x;
}
}
inline void dfs3 (int x) {
int y;
long long sum = 0;
for (int i = hd[x] ; i ; i = e[i].nt) {
y = e[i].to;
if (y == fa[x])
continue;
dfs3 (y);
sum += dp[y];
}
if (fg[x]) {
dp[x] = mind[x];
}
else {
dp[x] = min (mind[x] , sum);
}
hd[x] = 0;
}
inline void solve () {
for (int i = 1 ; i <= m ; i++) {
scanf ("%d" , &k);
for (int j = 1 ; j <= k ; j++){
scanf ("%d" , &h[j]);
}
cnt = 0;
sort(h + 1 , h + k + 1 , comp);
stk[1] = 1;
tp = 1;
for (int j = 1 ; j <= k ; j++) {
Insert (h[j]);
fg[h[j]] = 1;
}
for (int i = tp ; i > 1 ; i--) {
add(stk[i - 1] , stk[i] , 0);
}
dfs3 (1);
for (int j = 1 ; j <= k ; j++) {
fg[h[j]] = 0;
}
printf ("%lld\n" , dp[1]);
}
}
int main () {
scanf ("%d" , &n);
for (int i = 1 ; i < n ; i++) {
scanf ("%d%d%lld" , &u , &v , &wi);
add (u , v , wi) , add (v , u , wi);
}
dep[1] = 1;
for(int i=1;i<=n;i++) mind[i]=1e18;
dfs1 (1);
dfs2 (1);
cnt = 0;
memset(hd , 0 , sizeof(hd));
scanf ("%d" , &m);
solve ();
return 0;
}