You are given a complete undirected graph with n vertices. A number ai is assigned to each vertex, and the weight of an edge between vertices i and j is equal to ai xor aj.
Calculate the weight of the minimum spanning tree in this graph.
The first line contains n (1 ≤ n ≤ 200000) — the number of vertices in the graph.
The second line contains n integers a1, a2, ..., an (0 ≤ ai < 230) — the numbers assigned to the vertices.
Print one number — the weight of the minimum spanning tree in the graph.
5 1 2 3 4 5
8
4 1 2 3 4
8
题解:
这个题用的算法比较古老偏僻,反正在这之前我是没有听说过的。。。。。
1、Sollin算法介绍
Sollin(Boruvka)算法。
原理大概是这样的:刚开始把每个点看成是一个联通分量,然后同时对所有的联通分量进行扩展,这样的话,每次至少有一半数量的联通分量被合并。
合并的时候是这样进行操作的,首先拿出一个联通分量,然后从这个联通分量向其他的联通分量求一个最小边,然后把最小边两个端点相连的联通分量合并,再去枚举其他的联通分量,保证每次迭代的所有联通分量都被考虑过。
我们只需要迭代logn次就可以了。
2、Sollin算法在本题中的应用:
考虑到边是xor运算得到的,这是套路之一,我们首先建立一个0-1的trie树。
然后把所有的点都加进去。
每次遍历一个联通分量的时候,我们就把这个联通分量从Trie里面删除掉,然后枚举这个联通分量里面的点,对于这个点,在Trie里面找xor最小的点。
然后合并这两个联通分量就好了。
联通分量使用并查集来维护。
3、细节:
注意,这里不能用vector来存放联通分量,否则会超内存的。
正确的方法应该是:
把点按照他所属的联通分量进行排序,这样的话,属于同一个联通分量的点都在连续的一个区段里面,处理起来非常方便。
代码:
#include
#define convert(s,i) ((s>>i)&1)
using namespace std;
typedef pair P;
const int inf = 2e9;
const int maxn = 200007;
struct Trie{
int frq,nxt[2];
}pool[maxn*31];
int cnt;
int n;
void insert(int s){
int cur = 0;
for(int i = 30;i >= 0;--i){
int &pos = pool[cur].nxt[convert(s,i)];
if(!pos) pos = ++cnt;
cur = pos;
pool[cur].frq++;
}
}
int findxor(int s){
int cur = 0,ans = s;
for(int i = 30;i >= 0;--i){
int pos = pool[cur].nxt[convert(s,i)];
if(!(pos && pool[pos].frq)) pos = pool[cur].nxt[1^convert(s,i)],ans ^= (1<= 0;--i){
int pos = pool[cur].nxt[convert(s,i)];
cur = pos;
pool[cur].frq--;
}
}
int a[maxn],parent[maxn],used[maxn];
int find(int x){
return x == parent[x]?x:parent[x] = find(parent[x]);
}
int join(int x,int y){
int px = find(x);
int py = find(y);
if(px == py) return 0;
parent[py] = px;
return 1;
}
bool check(){
int f = 0;
for(int i = 1;i <= n;++i) f += parent[i] == i;
return f == 1;
}
long long res = 0;
P ps[maxn];
int main(){
cnt = 0;
cin>>n;
for(int i = 1;i <= n;++i) parent[i] = i;
memset(pool,0,sizeof(pool));
for(int i = 1;i <= n;++i) scanf("%d",&a[i]);
sort(a+1,a+1+n);n = unique(a+1,a+1+n) - (a+1);
for(int i = 1;i <= n;++i) insert(a[i]);
while(!check()){
memset(used,0,sizeof(used));
for(int i = 1;i <= n;++i) ps[i] = make_pair(find(i),i);
sort(ps+1,ps+1+n);
int pre = ps[1].first,last = 1;
for(int i = 1;i <= n;++i){
int u = ps[i].second;
if(!used[pre] && ps[i].first == pre) del(a[u]);
if(ps[i+1].first != pre){
if(used[find(u)]) {
for(int j = last;j <= i;j++) insert(a[ps[j].second]);
last = i+1;pre = ps[last].first;
continue;
}
used[pre] = 1;
int mi = inf,cv;
for(int j = last;j <= i;++j) {
int v = findxor(a[ps[j].second]);
if((v^a[ps[j].second]) < mi) mi = v^a[ps[j].second],cv = v;
}
res += mi;
for(int j = last;j <= i;++j) insert(a[ps[j].second]);
int pj = lower_bound(a+1,a+1+n,cv)-a;
pj = find(pj);
pre = find(u);
if(pre > pj) swap(pre,pj);
join(pre,pj);
pre = ps[i+1].first,last = i+1;
}
}
}
cout<