http://codeforces.com/problemset/problem/613/D
题意:
给出n个点的树,有q个询问,每次询问给出k个重要的点,问至少删掉多少个非重要的点,使得这个重要的点都不连通。
首先对于一个询问,无解的情况肯定是有两个相邻的重要点,直接特判就行。
对于其他情况,进行一次tree dp。
如果当前点为非重要点,且它的子树中有至少两个重要点,那么将这个点删掉。
如果当前点为重要点,对于其中一个儿子,如果这个儿子的子树中有重要点,那么将这个点的儿子删掉。
但是询问太多,就要用到虚树了。
考虑每次询问,有实际作用的点只有:重要点以及这些点的lca。那么我们对于每个询问,都建一个新图,只把这些点抽出来tree dp。
注意在新图中相邻的点不一定在原图中相邻。
关于建新图,当我们把倍增lca预处理好之后,原图就没有用了,所以我们把链式前向星清空,然后就是普通建图姿势了。
/* Footprints In The Blood Soaked Snow */ #include <cstdio> #include <vector> #include <algorithm> using namespace std; const int maxn = 100005, maxd = 18, inf = 0x3f3f3f3f; int n, head[maxn], cnt, st[maxn], ed[maxn], depth[maxn], fa[maxn][maxd], clo, sta[maxn], f[maxn][2]; bool imp[maxn]; struct _edge { int v, next; } g[maxn << 1]; inline int iread() { int f = 1, x = 0; char ch = getchar(); for(; ch < '0' || ch > '9'; ch = getchar()) f = ch == '-' ? -1 : 1; for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0'; return f * x; } inline void add(int u, int v) { g[cnt] = (_edge){v, head[u]}; head[u] = cnt++; } inline void dfs(int x) { st[x] = ++clo; for(int i = head[x]; ~i; i = g[i].next) if(g[i].v != fa[x][0]) { fa[g[i].v][0] = x; depth[g[i].v] = depth[x] + 1; dfs(g[i].v); } ed[x] = clo; } inline int getlca(int u, int v) { if(depth[u] < depth[v]) swap(u, v); for(int i = maxd - 1; i >= 0; i--) if(depth[fa[u][i]] >= depth[v]) u = fa[u][i]; for(int i = maxd - 1; i >= 0; i--) if(fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i]; return u == v ? u : fa[u][0]; } int node[maxn]; inline int dp(int x) { int tot = 0, ans = 0; for(int i = head[x]; ~i; i = g[i].next) { ans += dp(g[i].v); tot += node[g[i].v]; } if(imp[x]) { ans += tot; node[x] = 1; } else { ans += tot > 1; node[x] = tot == 1; } return ans; } inline bool cmp(int a, int b) { return st[a] < st[b]; } int main() { n = iread(); for(int i = 0; i < maxn; i++) head[i] = -1; cnt = clo = 0; for(int i = 1; i < n; i++) { int x = iread(), y = iread(); x--, y--; add(x, y); add(y, x); } depth[0] = fa[0][0] = 0; dfs(0); for(int j = 1; j < maxd; j++) for(int i = 0; i < n; i++) fa[i][j] = fa[fa[i][j - 1]][j - 1]; int T = iread(); while(T--) { vector<int> v; int m = iread(); for(int i = 0; i < m; i++) { int x = iread(); x--; v.push_back(x); imp[x] = 1; } bool flag = 0; for(int i = 0; i < v.size(); i++) if(fa[v[i]][0] != v[i] && imp[fa[v[i]][0]]) { flag = 1; break; } if(flag) { printf("-1\n"); for(int i = 0; i < v.size(); i++) imp[v[i]] = 0; continue; } sort(v.begin(), v.end(), cmp); for(int i = 1; i < m; i++) v.push_back(getlca(v[i - 1], v[i])); sort(v.begin(), v.end(), cmp); v.resize(unique(v.begin(), v.end()) - v.begin()); int top = 0; cnt = 0; for(int i = 0; i < v.size(); i++) { int x = v[i]; head[x] = -1; for(; top && !(st[sta[top]] <= st[x] && st[x] <= ed[sta[top]]); top--); if(top) add(sta[top], x); sta[++top] = x; } printf("%d\n", dp(v[0])); for(int i = 0; i < v.size(); i++) imp[v[i]] = 0; } return 0; }