给定一张无向连通图,存在边权 c c 与点权 a a 。
加入 K K 条特殊边,构造出这些边的边权使得存在一棵最小生成树使得所有点到 1 1 号点的距离(只考虑特殊边的长度)乘以该点点权的乘积最大。
n⩽1000000,m⩽300000,k⩽20 n ⩽ 1000000 , m ⩽ 300000 , k ⩽ 20
首先强制选择 k k 条特殊边,然后做最小生成树。
这次选进 MST M S T 的边一定会选进最后的 MST M S T 。所以先用这些边进行缩点,这样就只剩 k k 条特殊边和至多 k2 k 2 条特殊边。
直接 O(k2) O ( k 2 ) 枚举选择哪些特殊边,然后用 k2 k 2 条非特殊边使树联通并求出每条特殊边的边权。
时间复杂度 O(mlogm+2k×k2) O ( m l o g m + 2 k × k 2 )
#include
using namespace std;
typedef long long lint;
const int maxn = 100005, maxm = 300005;
const int Inf = 1 << 30;
int n, m, cnt, tot, tp, root, c[35], fa[2][maxn], f[maxn], d[maxn];
lint siz[maxn], val[maxn];
struct node {
int u, v, w;
bool operator < (const node &a) const {
return w < a.w;
}
}a[maxm], b[35];
inline int gi()
{
char c = getchar();
while (c < '0' || c > '9') c = getchar();
int sum = 0;
while ('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
return sum;
}
int getfa(int k, int x)
{
if (fa[k][x] == x) return x;
else return fa[k][x] = getfa(k, fa[k][x]);
}
struct edge {
int to, next;
}e[maxn * 2];
int h[maxn], Tot;
inline void add(int u, int v)
{
e[++Tot] = (edge) {v, h[u]}; h[u] = Tot;
e[++Tot] = (edge) {u, h[v]}; h[v] = Tot;
}
void dfs(int u)
{
siz[u] = val[u];
for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
if (v != fa[1][u]) {
fa[1][v] = u; d[v] = d[u] + 1; dfs(v);
siz[u] += siz[v];
}
}
int main()
{
n = gi(); m = gi(); cnt = gi();
for (int i = 1; i <= m; ++i)
a[i] = (node) {gi(), gi(), gi()};
sort(a + 1, a + m + 1);
for (int i = 1; i <= n; ++i) fa[0][i] = fa[1][i] = i;
for (int u, v, i = 1; i <= cnt; ++i) {
b[i].u = gi(); b[i].v = gi();
u = getfa(0, b[i].u), v = getfa(0, b[i].v);
if (u != v) fa[0][u] = v;
}
for (int u, v, i = 1; i <= m; ++i) {
u = getfa(0, a[i].u); v = getfa(0, a[i].v);
if (u != v) {fa[0][u] = v; fa[1][getfa(1, a[i].u)] = getfa(1, a[i].v);}
}
root = getfa(1, 1);
for (int i = 1; i <= n; ++i) val[getfa(1, i)] += gi();
for (int i = 1; i <= n; ++i) if (getfa(1, i) == i) c[++c[0]] = i;
for (int i = 1; i <= cnt; ++i)
b[i].u = getfa(1, b[i].u), b[i].v = getfa(1, b[i].v);
for (int i = 1; i <= m; ++i)
a[i].u = getfa(1, a[i].u), a[i].v = getfa(1, a[i].v);
for (int u, v, i = 1; i <= m; ++i) {
u = getfa(1, a[i].u), v = getfa(1, a[i].v);
if (u != v) fa[1][u] = v, a[++tp] = a[i];
}
lint ans = 0;
for (int k = 0; k < (1 << cnt); ++k) {
Tot = 0;
for (int u, i = 1; i <= c[0]; ++i) {
u = c[i]; h[u] = fa[1][u] = 0;
fa[0][u] = u; f[u] = Inf;
}
bool flag = true;
for (int u, v, i = 1; i <= cnt; ++i)
if (k & (1 << (i - 1))) {
u = getfa(0, b[i].u); v = getfa(0, b[i].v);
if (u == v) {flag = false; break;} fa[0][u] = v;
add(b[i].u, b[i].v);
}
if (!flag) continue;
for (int u, v, i = 1; i <= cnt; ++i) {
u = getfa(0, a[i].u); v = getfa(0, a[i].v);
if (u != v) {fa[0][u] = v; add(a[i].u, a[i].v);}
}
dfs(root);
for (int u, v, i = 1; i <= cnt; ++i) {
u = a[i].u; v = a[i].v;
if (d[u] < d[v]) swap(u, v);
while (d[u] != d[v]) f[u] = min(f[u], a[i].w), u = fa[1][u];
while (u != v) {
f[u] = min(f[u], a[i].w); f[v] = min(f[v], a[i].w);
u = fa[1][u]; v = fa[1][v];
}
}
lint tmp = 0;
for (int u, v, i = 1; i <= cnt; ++i)
if (k & (1 << (i - 1))) {
u = b[i].u; v = b[i].v;
if (d[u] < d[v]) swap(u, v);
tmp += siz[u] * f[u];
}
ans = max(ans, tmp);
}
printf("%lld\n", ans);
return 0;
}