更好的浏览体验 Press Here
传送门 >ω<
题目大意:
有 n n 个点,每个点权值为 a[i] a [ i ] ,两个点连边费用为 a[i] xor a[j] a [ i ] x o r a [ j ] ,问 最小生成树的边权和 and a n d 方案数
最小生成树么,一般使用 Kruskal 算法
但是在这里,由于边数达到了 1e10 1 e 10 的级别,显然是不能直接排序的
但是只要抓住 Kruskal 的精髓:边权从低到高合并
那么考虑一下如何找出边权最低的点对:按照二进制位从低到高合并
每个点的二进制状态我们可以用一颗 Trie 来维护,显然其 LCA 越深其异或值越小
所以将 Trie 从低到高合并的过程是这样的
Kruskal(T):
T 右子树非空 Kruskal(T 右子树)
T 左子树非空 Kruskal(T 左子树)
//这样就形成了两个联通块,只需要在两个块内选择一条边相连即可
若左右子树均非空
在左右子树中找出异或值最小点对
计算方案数
最小生成树 边权+ 方案数*
找异或值最小点对过程如下
cal(x , y):
x 的左子树非空 && y 的左子树非空 cal(x 的左子树 , y 的左子树)
x 的右子树非空 && y 的右子树非空 cal(x 的右子树 , y 的右子树)
若没有相同的子树(左左/右右)
x 的左子树非空 cal(x 的左子树 , y 的右子树)
x 的右子树非空 cal(x 的右子树 , y 的左子树)
返回最小异或值,方案数
这样利用其左右子树独立的性质就可以很简单的解决这道题
这个题目的复杂度很多人认为就是 O(nlog2n) O ( n log 2 n )
但是真的是这样么
一棵满二叉树能够使得每次计算最小点对时访问到所有儿子节点
满二叉树一共有 logn log n 层,这是由节点个数所限制的,而链的长度为 logmax(a[i]) log m a x ( a [ i ] )
所以当满二叉树的所有叶子节点下方都挂着一条长为 logmax(a[i])−logn log m a x ( a [ i ] ) − log n 的链时,达到最大复杂度
这时,满二叉树的每个节点被访问 dep[x] d e p [ x ] 次,链上每个节点被访问 logn log n 次
总复杂度为 O(nlogn(loga[i]n)+nlogn) O ( n log n ( log a [ i ] n ) + n log n )
#include
using namespace std;
const int N = 1000010;
const int mod = 1000000007;
typedef long long ll;
int s[N << 5][2];
int c[N << 5];
int n , cnt = 1 , m1 , c1;
long long ans1 = 1 , ans2;
int read() {
int ans = 0 , flag = 1;
char ch = getchar();
while(ch > '9' || ch < '0') {if(ch == '-') flag = -1; ch = getchar();}
while(ch <= '9' && ch >= '0') {ans = ans * 10 + ch - '0'; ch = getchar();}
return ans * flag;
}
int qpow(int a , int b) {
int ans = 1;
while(b) {
if(b & 1) ans = 1ll * ans * a % mod;
a = 1ll * a * a % mod;
b >>= 1;
}
return ans;
}
void insert(int x) {
int now = 1;
for(int i = 29 ; ~ i ; -- i) {
if(!s[now][(x >> i) & 1]) s[now][(x >> i) & 1] = ++ cnt;
now = s[now][(x >> i) & 1];
}
++ c[now];
}
void get_min(int x , int y , int d , int v = 0) {
if(d < 0) {
if(v < m1) {m1 = v; c1 = 1ll * c[x] * c[y] % mod;}
else if(v == m1) {c1 = (1ll * c1 + 1ll * c[x] * c[y] % mod) % mod;}
}
if(s[x][0] && s[y][0]) {
get_min(s[x][0] , s[y][0] , d - 1 , v);
if(s[x][1] && s[y][1]) get_min(s[x][1] , s[y][1] , d - 1 , v);
}
else if(s[x][1] && s[y][1]) get_min(s[x][1] , s[y][1] , d - 1 , v);
else {
if(s[x][0]) get_min(s[x][0] , s[y][1] , d - 1 , v + (1 << d));
if(s[x][1]) get_min(s[x][1] , s[y][0] , d - 1 , v + (1 << d));
}
}
void solve(int x , int d) {
if(d < 0) {
if(c[x] > 2) ans1 = ans1 * qpow(c[x] , c[x] - 2) % mod;
return;
}
if(!s[x][0]) solve(s[x][1] , d - 1);
else if(!s[x][1]) solve(s[x][0] , d - 1);
else {
solve(s[x][0] , d - 1);
solve(s[x][1] , d - 1);
m1 = 1 << 30 , c1 = 0;
get_min(s[x][0] , s[x][1] , d - 1);
ans2 = ans2 + (1ll << d) + m1;
ans1 = ans1 * c1 % mod;
}
}
int main() {
n = read();
for(int i = 1 ; i <= n ; ++ i) insert(read());
solve(1 , 29);
printf("%lld\n%lld\n", ans2 , ans1);
return 0;
}