给定n个结点的树,定义G(k)为n个结点的图,u,v之间有边当且仅当u,v在树上的距离大于等于k。对任意k(1 <= k <= n), 求G(k)连通分量的个数

题目

思路:

给定n个结点的树,定义G(k)为n个结点的图,u,v之间有边当且仅当u,v在树上的距离大于等于k。对任意k(1 <= k <= n), 求G(k)连通分量的个数_第1张图片

代码优化:

在找直径端点的时候把其他结点到直径两端点的距离都求出来,不用lca

#include 
using namespace std;
#define int long long
#define pb push_back
#define fi first
#define se second
#define lson p << 1
#define rson p << 1 | 1
const int maxn = 1e6 + 5, inf = 1e18, maxm = 4e4 + 5, mod = 1e9 + 7, N = 1e6;
int a[maxn], b[maxn];
// int k[maxn];
// bool vis[maxn];
int n, m;
string s;
vector G[maxn];
int p[maxn];
int fa[21][maxn], dep[maxn];
int ans[maxn];

int find(int x){
	if(x == p[x]) return x;
	return p[x] = find(p[x]);
}

void dfs(int u, int father){
    fa[0][u] = father;
	for(int i = 1; i <= 20; i++){
		if((1 << i) > dep[u]) break;
		int t = fa[i - 1][u];
		fa[i][u] = fa[i - 1][t];
	}
	for(auto v : G[u]){
		if(v == father) continue;
		dep[v] = dep[u] + 1;
		dfs(v, u);
	}
}

void dfs2(int u, int father){
	for(auto v : G[u]){
		if(v == father) continue;
		dep[v] = dep[u] + 1;
		dfs2(v, u);
	}
}

int get(){
	int mx = -1, id = 0;
	for(int i = 1; i <= n; i++){
		if(dep[i] > mx){
			mx = dep[i];
			id = i;
		}
	}
	return id;
}

int lca(int x, int y){
	if(dep[x] < dep[y]) swap(x, y);
	int dif = dep[x] - dep[y];
	for(int i = 0; i < 20; i++){
		if((dif >> i) & 1) x = fa[i][x];
	}
	if(x == y) return x;
	for(int i = log2(dep[x]); i >= 0; i--){
		if(fa[i][x] != fa[i][y]){
			x = fa[i][x];
			y = fa[i][y];
		}
	}
	return fa[0][x];
}

int dis(int u, int v){
	int Lca = lca(u, v);
	return dep[u] + dep[v] - 2 * dep[Lca];
}

void solve(){
    int res = 0;
    // int k;
    int x;
    int q;
    cin >> n;
	for(int i = 1; i <= n; i++){
		p[i] = i;
	}
	int tot = n;
	for(int i = 1; i < n; i++){
		int u, v;
		cin >> u >> v;
		G[u].pb(v);
		G[v].pb(u);
	}
	dep[1] = 0;
	dfs2(1, 1);
	int st = get();
	dep[st] = 0;
	dfs2(st, st);
	int ed = get();
	// cout << st << ' ' << ed << '\n';
	dep[1] = 0;
	dfs(1, 1);
	vector> vec;
	for(int i = 1; i <= n; i++){
		vec.pb({max(dis(i, st), dis(i, ed)), i});
	}
	sort(vec.begin(), vec.end(), greater>());
	int j = 0;
	res = n;
	int d = dis(st, ed);
	// cout << d << '\n';
	for(int i = d + 1; i <= n; i++){
		ans[i] = n;
	}
	p[ed] = st;
	res = n - 1;
	for(int k = d; k >= 1; k--){
		while(j < n && vec[j].first >= k){
			int id = vec[j].second, dis = vec[j].first;
			if(find(id) != find(st)){
				res--;
				p[find(id)] = find(st);
			}
			j++;
			
		}
		ans[k] = res;
	}
	for(int i = 1; i <= n; i++){
		cout << ans[i] << " \n"[i == n];
	}
}
    
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    int T = 1;
    // cin >> T;
    while (T--)
    {
        solve();
    }
    return 0;
}

你可能感兴趣的:(codeforces,算法)