题解 CF888G Xor-MST【01-Trie Boruvka】

题目链接

题意

有一 n n n 个点的完全图,点有点权 a i a_i ai,边 ( i , j ) (i,j) (i,j) 有边权 a i ⊕ a j a_i\oplus a_j aiaj(异或),求其最小生成树。 n ≤ 2 × 1 0 5 n\leq 2\times10^5 n2×105

题解

最小生成树有一种冷门算法叫做 Boruvka。其大致思想是:

  • 初始有 n n n 个点各自为一个连通块,形成一个森林;
  • 反复进行以下操作直到森林变成树:
    • 找到每个连通块连向其外部的最小边;
    • 将这些边都连上(有重复也不必管)。

显然每次操作是 O ( n + m ) O(n+m) O(n+m) 的,而每次操作会使连通块个数至少减半,因此时间复杂度是 O ( ( n + m ) log ⁡ n ) O((n+m)\log n) O((n+m)logn)

这种算法尤其适用于边权由点权计算得到的情况(比如此题),因为每个连通块连向其外部的最小边可以用数据结构维护或者其他奇奇怪怪的方法快速得到,时间复杂度也就降低了。

对于此题来说,我们把所有点权扔到 01-Trie 上。每个连通块的最小出边肯定通向与它当中某个点最高位尽可能多地相同的一个点。于是对于 01-Trie 上一个有两个儿子的节点,找到左边和右边的节点中异或和最小的一对连起来。这个不太好维护,因此用启发式合并,枚举 size 较小的一边中的数,拿到另一边来询问。

这个过程也可以用 Kruskal 来理解:当我们在 Trie 上自底往上合并连通块时,我们也就在不断允许加入边权较大边。

代码:

/**********
Author: WLBKR5
Problem: codeforces 888G
Name: Xor-MST
Source: codeforces
Algorithm: 01-Trie, Boruvka 
Date: 2020/06/20
Statue: accepted
Submission: codeforces.com/contest/888/submission/84385568
**********/
#include
using namespace std;
int getint(){
	int ans=0,f=1;
	char c=getchar();
	while(c<'0'||c>'9'){
		if(c=='-')f=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9'){
		ans=ans*10+c-'0';
		c=getchar();
	} 
	return ans*f;
}
const int N=2e5+10,L=31,M=N*L;
int ch[M][2],sz[M],cnt=1;
void add(int x){
	int u=1;
	for(int i=L-1;i>=0;--i){
		int t=(x>>i)&1;
		sz[u]++;
		if(!ch[u][t])ch[u][t]=++cnt;
		u=ch[u][t];
	}
	sz[u]++;
}
int query(int x,int val,int l){
	int ans=0;
	for(int i=l;i>=0;--i){
		int t=(val>>i)&1;
		if(ch[x][t])x=ch[x][t];
		else x=ch[x][t^1],ans|=(1<<i);
	}
	return ans;
}
vector<int>v;
long long ans=0;
void solve(int u,int val,int l){
	//cerr<<"solve "<
	if(l<0){
		for(int i=0;i<sz[u];i++)v.push_back(val);
		//cerr<<">> "<
		return;
	}
	if(!ch[u][0]){
		solve(ch[u][1],val|(1<<l),l-1);
		return;
	}
	if(!ch[u][1]){
		solve(ch[u][0],val,l-1);
		return;
	}
	int c1=ch[u][0],c2=ch[u][1];
	if(sz[c1]>sz[c2])swap(c1,c2);
	solve(c2,val|((c2==ch[u][1])<<l),l-1);
	int qaq=v.size();
	solve(c1,val|((c1==ch[u][1])<<l),l-1);
	int mn=0x7f7f7f7f;
	for(int i=qaq;i<v.size();i++){
		mn=min(mn,(1<<l)|query(c2,v[i],l-1));
	}
	ans+=mn;
	return;
}

int a[N];
int main(){
	int n=getint();
	//add(0);
	for(int i=1;i<=n;i++){
		a[i]=getint();
		add(a[i]);
		//cerr<<"add "<
	}
	solve(1,0,L-1);
	cout<<ans;
	return 0;
}

你可能感兴趣的:(题解,#,来源-codeforces,#,图论-Boruvka)