codeforces 888G Xor-MST Sollin算法求最小生成树,0-1异或True

G. Xor-MST
time limit per test
2 seconds
memory limit per test
256 megabytes
input
standard input
output
standard output

You are given a complete undirected graph with n vertices. A number ai is assigned to each vertex, and the weight of an edge between vertices i and j is equal to aixor aj.

Calculate the weight of the minimum spanning tree in this graph.

Input

The first line contains n (1 ≤ n ≤ 200000) — the number of vertices in the graph.

The second line contains n integers a1a2, ..., an (0 ≤ ai < 230) — the numbers assigned to the vertices.

Output

Print one number — the weight of the minimum spanning tree in the graph.

Examples
input
5
1 2 3 4 5
output
8
input
4
1 2 3 4
output
8

题解:

这个题用的算法比较古老偏僻,反正在这之前我是没有听说过的。。。。。

1、Sollin算法介绍

Sollin(Boruvka)算法。

原理大概是这样的:刚开始把每个点看成是一个联通分量,然后同时对所有的联通分量进行扩展,这样的话,每次至少有一半数量的联通分量被合并。

合并的时候是这样进行操作的,首先拿出一个联通分量,然后从这个联通分量向其他的联通分量求一个最小边,然后把最小边两个端点相连的联通分量合并,再去枚举其他的联通分量,保证每次迭代的所有联通分量都被考虑过。

我们只需要迭代logn次就可以了。

2、Sollin算法在本题中的应用:

考虑到边是xor运算得到的,这是套路之一,我们首先建立一个0-1的trie树。

然后把所有的点都加进去。

每次遍历一个联通分量的时候,我们就把这个联通分量从Trie里面删除掉,然后枚举这个联通分量里面的点,对于这个点,在Trie里面找xor最小的点。

然后合并这两个联通分量就好了。

联通分量使用并查集来维护。

3、细节:

注意,这里不能用vector来存放联通分量,否则会超内存的。

正确的方法应该是:

点按照他所属的联通分量进行排序,这样的话,属于同一个联通分量的点都在连续的一个区段里面,处理起来非常方便。

代码:

#include
#define convert(s,i) ((s>>i)&1)
using namespace std;
typedef pair P;
const int inf = 2e9;
const int maxn = 200007;
struct Trie{
	int frq,nxt[2];
}pool[maxn*31];
int cnt;
int n;
void insert(int s){
	int cur = 0;
	for(int i = 30;i >= 0;--i){
		int &pos = pool[cur].nxt[convert(s,i)];
		if(!pos) pos = ++cnt;
		cur = pos;
		pool[cur].frq++;
	}
	
}
int findxor(int s){
	int cur = 0,ans = s;
	for(int i = 30;i >= 0;--i){
		int pos = pool[cur].nxt[convert(s,i)];
		if(!(pos && pool[pos].frq)) pos = pool[cur].nxt[1^convert(s,i)],ans ^= (1<= 0;--i){
		int pos = pool[cur].nxt[convert(s,i)];
		cur = pos;
		pool[cur].frq--;
	}
}
int a[maxn],parent[maxn],used[maxn];
int find(int x){
	return x == parent[x]?x:parent[x] = find(parent[x]);
}
int join(int x,int y){
	int px = find(x);
	int py = find(y);
	if(px == py) return 0;
	parent[py] = px;
	return 1;
}
bool check(){
	int f = 0;
	for(int i = 1;i <= n;++i) f += parent[i] == i;
	return f == 1;
}
long long res = 0;
P ps[maxn];
int main(){
	cnt = 0;
	cin>>n;
	for(int i = 1;i <= n;++i) parent[i] = i;
	memset(pool,0,sizeof(pool));
	for(int i = 1;i <= n;++i)  scanf("%d",&a[i]);
	sort(a+1,a+1+n);n = unique(a+1,a+1+n) - (a+1);
	for(int i = 1;i <= n;++i) insert(a[i]);
	while(!check()){
		memset(used,0,sizeof(used));
		for(int i = 1;i <= n;++i) ps[i] = make_pair(find(i),i);
		sort(ps+1,ps+1+n);
		int pre = ps[1].first,last = 1;
		for(int i = 1;i <= n;++i){
			int u = ps[i].second;
			if(!used[pre] && ps[i].first == pre) del(a[u]);
			if(ps[i+1].first != pre){
				if(used[find(u)]) {
					for(int j = last;j <= i;j++) insert(a[ps[j].second]);
					last = i+1;pre = ps[last].first;
					continue;
				}
				used[pre] = 1;
				int mi = inf,cv;
				for(int j = last;j <= i;++j) {
					int v = findxor(a[ps[j].second]);
					if((v^a[ps[j].second]) < mi) mi = v^a[ps[j].second],cv = v;
				} 
				res += mi;
				for(int j = last;j <= i;++j) insert(a[ps[j].second]);
				int pj = lower_bound(a+1,a+1+n,cv)-a;
				pj = find(pj);
				pre = find(u);
				if(pre > pj) swap(pre,pj);
				join(pre,pj);
				pre = ps[i+1].first,last = i+1;
			}
		}
	}	
	cout<








你可能感兴趣的:(ACM-ICPC训练题解,CODEFORCES训练记录)