[APIO2013]道路费用

Description

给定一张无向连通图,存在边权 c c 与点权 a a

加入 K K 条特殊边,构造出这些边的边权使得存在一棵最小生成树使得所有点到 1 1 号点的距离(只考虑特殊边的长度)乘以该点点权的乘积最大。

n1000000,m300000,k20 n ⩽ 1000000 , m ⩽ 300000 , k ⩽ 20

Solution

首先强制选择 k k 条特殊边,然后做最小生成树。
这次选进 MST M S T 的边一定会选进最后的 MST M S T 。所以先用这些边进行缩点,这样就只剩 k k 条特殊边和至多 k2 k 2 条特殊边。
直接 O(k2) O ( k 2 ) 枚举选择哪些特殊边,然后用 k2 k 2 条非特殊边使树联通并求出每条特殊边的边权。
时间复杂度 O(mlogm+2k×k2) O ( m l o g m + 2 k × k 2 )

#include 
using namespace std;

typedef long long lint;
const int maxn = 100005, maxm = 300005;
const int Inf = 1 << 30;

int n, m, cnt, tot, tp, root, c[35], fa[2][maxn], f[maxn], d[maxn];
lint siz[maxn], val[maxn];

struct node {
    int u, v, w;
    bool operator < (const node &a) const {
        return w < a.w;
    }
}a[maxm], b[35];

inline int gi()
{
    char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    int sum = 0;
    while ('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
    return sum;
}

int getfa(int k, int x) 
{
    if (fa[k][x] == x) return x;
    else return fa[k][x] = getfa(k, fa[k][x]);
}

struct edge {
    int to, next;
}e[maxn * 2];
int h[maxn], Tot;

inline void add(int u, int v)
{
    e[++Tot] = (edge) {v, h[u]}; h[u] = Tot;
    e[++Tot] = (edge) {u, h[v]}; h[v] = Tot;
}

void dfs(int u)
{
    siz[u] = val[u];
    for (int i = h[u], v; v = e[i].to, i; i = e[i].next) 
        if (v != fa[1][u]) {
            fa[1][v] = u; d[v] = d[u] + 1; dfs(v);
            siz[u] += siz[v];
        }
}

int main()
{
    n = gi(); m = gi(); cnt = gi();
    for (int i = 1; i <= m; ++i) 
        a[i] = (node) {gi(), gi(), gi()};
    sort(a + 1, a + m + 1);

    for (int i = 1; i <= n; ++i) fa[0][i] = fa[1][i] = i;

    for (int u, v, i = 1; i <= cnt; ++i) {
        b[i].u = gi(); b[i].v = gi(); 
        u = getfa(0, b[i].u), v = getfa(0, b[i].v);
        if (u != v) fa[0][u] = v;
    }

    for (int u, v, i = 1; i <= m; ++i) {
        u  = getfa(0, a[i].u); v = getfa(0, a[i].v);
        if (u != v) {fa[0][u] = v; fa[1][getfa(1, a[i].u)] = getfa(1, a[i].v);}
    }

    root = getfa(1, 1);

    for (int i = 1; i <= n; ++i) val[getfa(1, i)] += gi();
    for (int i = 1; i <= n; ++i) if (getfa(1, i) == i) c[++c[0]] = i;
    for (int i = 1; i <= cnt; ++i) 
        b[i].u = getfa(1, b[i].u), b[i].v = getfa(1, b[i].v);
    for (int i = 1; i <= m; ++i) 
        a[i].u = getfa(1, a[i].u), a[i].v = getfa(1, a[i].v);
    for (int u, v, i = 1; i <= m; ++i) {
        u = getfa(1, a[i].u), v = getfa(1, a[i].v);
        if (u != v) fa[1][u] = v, a[++tp] = a[i];
    }

    lint ans = 0;
    for (int k = 0; k < (1 << cnt); ++k) {
        Tot = 0;
        for (int u, i = 1; i <= c[0]; ++i) {
            u = c[i]; h[u] = fa[1][u] = 0;
            fa[0][u] = u; f[u] = Inf;
        }
        bool flag = true;
        for (int u, v, i = 1; i <= cnt; ++i)
            if (k & (1 << (i - 1))) {
                u = getfa(0, b[i].u); v = getfa(0, b[i].v);
                if (u == v) {flag = false; break;} fa[0][u] = v;
                add(b[i].u, b[i].v);
            }
        if (!flag) continue;
        for (int u, v, i = 1; i <= cnt; ++i) {
            u = getfa(0, a[i].u); v = getfa(0, a[i].v);
            if (u != v) {fa[0][u] = v; add(a[i].u, a[i].v);}
        }
        dfs(root);
        for (int u, v, i = 1; i <= cnt; ++i) {
            u = a[i].u; v = a[i].v;
            if (d[u] < d[v]) swap(u, v);
            while (d[u] != d[v]) f[u] = min(f[u], a[i].w), u = fa[1][u];
            while (u != v) {
                f[u] = min(f[u], a[i].w); f[v] = min(f[v], a[i].w);
                u = fa[1][u]; v = fa[1][v];
            }
        }
        lint tmp = 0;
        for (int u, v, i = 1; i <= cnt; ++i)
            if (k & (1 << (i - 1))) {
                u = b[i].u; v = b[i].v;
                if (d[u] < d[v]) swap(u, v);
                tmp += siz[u] * f[u];
            }
        ans = max(ans, tmp);
    }

    printf("%lld\n", ans);
    return 0;
}

你可能感兴趣的:(图论——MST)