51NOD 1601 完全图的最小生成树计数 Trie

更好的浏览体验 Press Here

Problem

传送门 >ω<

题目大意:
n n 个点,每个点权值为 a[i] a [ i ] ,两个点连边费用为 a[i]  xor  a[j] a [ i ]     x o r     a [ j ] ,问 最小生成树的边权和 and a n d 方案数

Solution

最小生成树么,一般使用 Kruskal 算法
但是在这里,由于边数达到了 1e10 1 e 10 的级别,显然是不能直接排序的

但是只要抓住 Kruskal 的精髓:边权从低到高合并
那么考虑一下如何找出边权最低的点对:按照二进制位从低到高合并

每个点的二进制状态我们可以用一颗 Trie 来维护,显然其 LCA 越深其异或值越小
所以将 Trie 从低到高合并的过程是这样的

Kruskal(T):
    T 右子树非空 Kruskal(T 右子树)
    T 左子树非空 Kruskal(T 左子树)
    //这样就形成了两个联通块,只需要在两个块内选择一条边相连即可
    若左右子树均非空
        在左右子树中找出异或值最小点对
        计算方案数
    最小生成树 边权+ 方案数*

找异或值最小点对过程如下

cal(x , y):
    x 的左子树非空 && y 的左子树非空 cal(x 的左子树 , y 的左子树)
    x 的右子树非空 && y 的右子树非空 cal(x 的右子树 , y 的右子树)
    若没有相同的子树(左左/右右)
    x 的左子树非空 cal(x 的左子树 , y 的右子树)
    x 的右子树非空 cal(x 的右子树 , y 的左子树)
    返回最小异或值,方案数

这样利用其左右子树独立的性质就可以很简单的解决这道题

分析

这个题目的复杂度很多人认为就是 O(nlog2n) O ( n log 2 ⁡ n )
但是真的是这样么

一棵满二叉树能够使得每次计算最小点对时访问到所有儿子节点

满二叉树一共有 logn log ⁡ n 层,这是由节点个数所限制的,而链的长度为 logmax(a[i]) log ⁡ m a x ( a [ i ] )

所以当满二叉树的所有叶子节点下方都挂着一条长为 logmax(a[i])logn log ⁡ m a x ( a [ i ] ) − log ⁡ n 的链时,达到最大复杂度

这时,满二叉树的每个节点被访问 dep[x] d e p [ x ] 次,链上每个节点被访问 logn log ⁡ n

总复杂度为 O(nlogn(loga[i]n)+nlogn) O ( n log ⁡ n ( log ⁡ a [ i ] n ) + n log ⁡ n )

代码

#include 
using namespace std;
const int N = 1000010;
const int mod = 1000000007;
typedef long long ll;
int s[N << 5][2];
int c[N << 5];
int n , cnt = 1 , m1 , c1;
long long ans1 = 1 , ans2;
int read() {
    int ans = 0 , flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = -1; ch = getchar();}
    while(ch <= '9' && ch >= '0') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}
int qpow(int a , int b) {
    int ans = 1;
    while(b) {
        if(b & 1) ans = 1ll * ans * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return ans;
}
void insert(int x) {
    int now = 1;
    for(int i = 29 ; ~ i ; -- i) {
        if(!s[now][(x >> i) & 1]) s[now][(x >> i) & 1] = ++ cnt;
        now = s[now][(x >> i) & 1];
    }
    ++ c[now];
}
void get_min(int x , int y , int d , int v = 0) {
    if(d < 0) {
        if(v < m1) {m1 = v; c1 = 1ll * c[x] * c[y] % mod;}
        else if(v == m1) {c1 = (1ll * c1 + 1ll * c[x] * c[y] % mod) % mod;}
    }
    if(s[x][0] && s[y][0]) {
        get_min(s[x][0] , s[y][0] , d - 1 , v);
        if(s[x][1] && s[y][1]) get_min(s[x][1] , s[y][1] , d - 1 , v);
    }
    else if(s[x][1] && s[y][1]) get_min(s[x][1] , s[y][1] , d - 1 , v);
    else {
        if(s[x][0]) get_min(s[x][0] , s[y][1] , d - 1 , v + (1 << d));
        if(s[x][1]) get_min(s[x][1] , s[y][0] , d - 1 , v + (1 << d));
    }
}
void solve(int x , int d) {
    if(d < 0) {
        if(c[x] > 2) ans1 = ans1 * qpow(c[x] , c[x] - 2) % mod;
        return;
    }
    if(!s[x][0]) solve(s[x][1] , d - 1);
    else if(!s[x][1]) solve(s[x][0] , d - 1);
    else {
        solve(s[x][0] , d - 1);
        solve(s[x][1] , d - 1);
        m1 = 1 << 30 , c1 = 0;
        get_min(s[x][0] , s[x][1] , d - 1);
        ans2 = ans2 + (1ll << d) + m1;
        ans1 = ans1 * c1 % mod;
    }
}
int main() {
    n = read();
    for(int i = 1 ; i <= n ; ++ i) insert(read());
    solve(1 , 29);
    printf("%lld\n%lld\n", ans2 , ans1);
    return 0;
}

你可能感兴趣的:(【字符串】Trie)