POJ 3417 Network(进阶指南, 树上差分)

算法竞赛进阶指南,380页,树上差分

本题要点:
1、 附加边 (x, y) 把 x, y之间的路径上的每条边都"覆盖了一次", 需要统计每条主要边被覆盖了几次;
a) 主要边被覆盖0次,那么将该主要边打断,然后随意打断一条附加边即可;
b) 主要边被覆盖1次,那么将该主要边打断,只有唯一的 附加边被打断才能得到不连通的两部分;
c) 主要边被覆盖2次及以上,无论如何操作都不能打败 Dark

2、 这里涉及的是树上边的差分:
给每个节点初始为0的权值,然后对每条附加边(x, y), 使x点的权值加1, y节点的权值加1, LCA(x, y) 的权值减1;
深度递归, F[x] 表示以x为根的子树的各个几点的权值之和,那么 F[x] 就是x和它的父节点之间的树边的权值。

树上差分参考: https://www.cnblogs.com/TEoS/p/11376676.html
#include 
#include 
#include 
#include 
#include 
using namespace std;
const int MaxN = 100010;
int ver[MaxN * 2], head[MaxN * 2], Next[MaxN * 2];
long long edge[MaxN * 2];
int f[MaxN][20];	//f[i][k] 表示 点i 向上走 2^k 步到达的节点号(n个节点,节点编号从1到n)
int d[MaxN];		//d[i] 点i的深度
long long val[MaxN];		// val[i] 表示第i点的权值
bool vis[MaxN];
int n, m, depth;	//depth 表示树的深度
int tot;

void add(int x, int y, int z)
{
	ver[++tot] = y, edge[tot] = z, Next[tot] = head[x], head[x] = tot;	
}

void bfs()
{
	d[1] = 1;
	queue<int> q;
	q.push(1);
	while(!q.empty())
	{
		int x = q.front();
		q.pop();
		for(int i = head[x]; i; i = Next[i])
		{
			int y = ver[i];
			if(d[y])
			{
				continue;
			}
			d[y] = d[x] + 1;
			f[y][0] = x;	//点y的父节点是x
			for(int t = 1; t <= depth; ++t)
			{
				f[y][t] = f[f[y][t - 1]][t - 1];	
			}
			q.push(y);
		}
	}
}

int lca(int x, int y)	//节点x和y的最近公共祖先
{
	if(d[x] > d[y])
	{
		swap(x, y);	//使得 d[x] <= d[y], 然后调整 y
	}
	for(int t = depth; t >= 0; --t)//这个循环的目的是,使得y和x处于同一高度
	{
		if(d[f[y][t]] >= d[x])
		{
			y = f[y][t];
		}
	}
	if(x == y)
	{
		return x;
	}
	for(int t = depth; t >= 0; --t)
	{
		if(f[x][t] != f[y][t])
		{
			x = f[x][t], y = f[y][t];
		}
	}
	return f[x][0];
}

void dfs(int x)
{
	vis[x] = true;
	int sum = 0;
	for(int i = head[x]; i; i = Next[i])
	{
		int y = ver[i];
		if(vis[y])
		{
			continue;
		}
		dfs(y);
		sum += val[y];
		edge[i] = val[y];
	}
	val[x] += sum;
}

int main()
{
	int x, y;
	scanf("%d%d", &n, &m);
	depth = (int)(log(n) / log(2)) + 1;
	tot = 0;
	for(int i = 1; i <= n; ++i)
	{
		vis[i] = val[i] = head[i] = d[i] = 0;
	}
	for(int i = 1; i < n; ++i)
	{
		scanf("%d%d", &x, &y);		
		add(x, y, 0), add(y, x, 0);
	}
	bfs();
	for(int i = 0; i < m; ++i)
	{
		scanf("%d%d", &x, &y);
		val[x]++, val[y]++;
		val[lca(x, y)] -= 2;
	}
	dfs(1);
	long long ans = 0;
	for(int i = 2; i <= n; ++i)
	{
		if(0 == val[i])
		{
			ans += m;	
		}else if(1 == val[i]){
			++ans;
		}
	}
	printf("%lld\n", ans);
	return 0;
}

/*
4 1
1 2
2 3
1 4
3 4
*/

/*
3
*/

你可能感兴趣的:(算法竞赛进阶指南,图论,POJ)