LCA + 树状数组
要尽量删除少的点来满足答案,那么受限考虑两个点的LCA,这样删除以后可以获得最大收益即该点的子树任意一点无法到达子树外任意一点。
删除顺序的选择应该是先删除LCA深的较大的,可以画图试想一下。。先删除高的LCA,如果还有点对的LCA在其子树内,那我们还需要对删除一次,而如果我们直接先把LCA低的点对删除,另外哪个LCA高的点对自然也无法联通。
因为我们是按深度大小来一次删除点的,所以我们可以用树状数组来维护整个树的dfs序,如果删除了一个点,可以把它及他的子树全部+1,这样在保证深度递减的情况下,有点的值是1就代表无法到达。
#include
#define INF 0x3f3f3f3f
#define full(a, b) memset(a, b, sizeof a)
#define FAST_IO ios::sync_with_stdio(false), cin.tie(0), cout.tie(0)
using namespace std;
typedef long long ll;
inline int lowbit(int x){ return x & (-x); }
inline int read(){
int ret = 0, w = 0; char ch = 0;
while(!isdigit(ch)) { w |= ch == '-'; ch = getchar(); }
while(isdigit(ch)) ret = (ret << 3) + (ret << 1) + (ch ^ 48), ch = getchar();
return w ? -ret : ret;
}
inline int gcd(int a, int b){ return b ? gcd(b, a % b) : a; }
inline int lcm(int a, int b){ return a / gcd(a, b) * b; }
template
inline T max(T x, T y, T z){ return max(max(x, y), z); }
template
inline T min(T x, T y, T z){ return min(min(x, y), z); }
template
inline A fpow(A x, B p, C lyd){
A ans = 1;
for(; p; p >>= 1, x = 1LL * x * x % lyd)if(p & 1)ans = 1LL * x * ans % lyd;
return ans;
}
const int N = 20005;
int _, m, cnt, t, tot, head[N], in[N], out[N], p[N][30], depth[N], b[N];
struct Edge { int v, next; } edge[N<<1];
struct Query{
int u, v, lca;
bool operator < (const Query &rhs) const {
return depth[lca] > depth[rhs.lca];
}
}query[N];
void addEdge(int a, int b){
edge[cnt].v = b, edge[cnt].next = head[a], head[a] = cnt ++;
}
void build(){
cnt = tot = t = 0;
full(head, -1), full(in, 0), full(out, 0);
full(depth, 0), full(b, 0);
}
void dfs(int s, int fa){
in[s] = ++tot, p[s][0] = fa;
depth[s] = depth[fa] + 1;
for(int i = 1; i <= t; i ++) p[s][i] = p[p[s][i - 1]][i - 1];
for(int i = head[s]; i != -1; i = edge[i].next){
int u = edge[i].v;
if(u == fa) continue;
dfs(u, s);
}
out[s] = ++tot;
}
int lca(int x, int y){
if(depth[x] < depth[y]) swap(x, y);
for(int i = t; i >= 0; i --){
if(depth[p[x][i]] >= depth[y]) x = p[x][i];
}
if(x == y) return y;
for(int i = t; i >= 0; i --){
if(p[x][i] != p[y][i]) x = p[x][i], y = p[y][i];
}
return p[y][0];
}
void insert(int xi, int val){
for(; xi <= tot; xi += lowbit(xi)) b[xi] += val;
}
int solve(int xi){
int ret = 0;
for(; xi; xi -= lowbit(xi)) ret += b[xi];
return ret;
}
int main(){
while(~scanf("%d", &m)){
build();
t = (int)(log(m + 1) / log(2)) + 1;
//t = 20;
for(int i = 1; i <= m; i ++){
int u = read() + 1, v = read() + 1;
addEdge(u, v), addEdge(v, u);
}
dfs(1, 0);
int q = read();
for(int i = 1; i <= q; i ++){
query[i].u = read() + 1, query[i].v = read() + 1;
query[i].lca = lca(query[i].u, query[i].v);
}
sort(query + 1, query + q + 1);
int ans = 0;
for(int i = 1; i <= q; i ++){
int u = query[i].u, v = query[i].v, f = query[i].lca;
if(solve(in[u]) + solve(in[v])) continue;
ans ++, insert(in[f], 1), insert(out[f] + 1, -1);
}
printf("%d\n", ans);
}
return 0;
}