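// Approach: precompute shortest paths with Floyd–Warshall, then solve the ship-to-port assignment directly with the KM (Kuhn–Munkres) algorithm. Note: a ship may reach a port only once (a port can only be the end of a path), so ports must never serve as intermediate vertices in the Floyd preprocessing. In the Floyd loops, enumerate the intermediate vertex first (outermost loop), not last, and let it range only over the non-port vertices.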
#include <iostream>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <cmath>
#include <time.h>
#define maxn 305
#define maxm 400005
#define eps 1e-10
#define mod 1000000007
#define INF 999999999
#define PI (acos(-1.0))
#define lowbit(x) (x&(-x))
#define mp make_pair
#define ls o<<1
#define rs o<<1 | 1
#define lson o<<1, L, mid
#define rson o<<1 | 1, mid+1, R
#define pii pair<int, int>
#pragma comment(linker, "/STACK:16777216")
typedef long long LL;
typedef unsigned long long ULL;
//typedef int LL;
using namespace std;
LL qpow(LL a, LL b){LL res=1,base=a;while(b){if(b%2)res=res*base;base=base*base;b/=2;}return res;}
LL powmod(LL a, LL b){LL res=1,base=a;while(b){if(b%2)res=res*base%mod;base=base*base%mod;b/=2;}return res;}
// head

// Vertex layout used below: vertices [0, n) are the ports, vertices
// [n, n + m) are the non-port vertices; o = n + m is the vertex count.

// KM weight for a (ship, port) pair with no path. Every reachable distance is
// < INF and there are at most maxn ships, so any assignment using one of
// these pairs scores strictly worse than every all-reachable assignment.
// (The original used 0 here, which is the *best* possible weight.)
const LL UNREACHABLE = -(LL)INF * maxn;
// "Infinity" for KM labels/slack: far above any label magnitude reachable
// with the weights above, far enough below LLONG_MAX to never overflow.
const LL KM_INF = LLONG_MAX / 4;

LL G[maxn][maxn];   // KM weights: G[ship][port] = -(shortest distance)
int g[maxn][maxn];  // distances over all o vertices; INF = no path (yet)
int linker[maxn];   // linker[port] = ship matched to it, -1 if none
LL slack[maxn];
bool visx[maxn];
bool visy[maxn];
LL lx[maxn];
LL ly[maxn];
int a[maxn];        // a[i] = vertex index where ship i starts
int n, m, o, m1, m2, nx, ny;

// Looks for an augmenting path from left vertex x using only tight edges
// (lx[x] + ly[y] == G[x][y]); records the smallest non-zero gap per
// unvisited right vertex in slack[]. Returns true if x got matched.
bool dfs(int x)
{
    visx[x] = true;
    for (int y = 0; y < ny; y++) {
        if (visy[y]) continue;
        LL tmp = lx[x] + ly[y] - G[x][y];
        if (tmp == 0) {
            visy[y] = true;
            if (linker[y] == -1 || dfs(linker[y])) {
                linker[y] = x;
                return true;
            }
        } else {
            slack[y] = min(slack[y], tmp);
        }
    }
    return false;
}

// Kuhn–Munkres: maximum-weight perfect matching on the complete nx x ny
// bipartite graph G (nx == ny here). Returns the total matched weight.
LL km()
{
    memset(linker, -1, sizeof linker);
    memset(ly, 0, sizeof ly);
    for (int i = 0; i < nx; i++) {
        lx[i] = -KM_INF;
        for (int j = 0; j < ny; j++) lx[i] = max(lx[i], G[i][j]);
    }
    for (int x = 0; x < nx; x++) {
        for (int i = 0; i < ny; i++) slack[i] = KM_INF;
        while (true) {
            memset(visx, 0, sizeof visx);
            memset(visy, 0, sizeof visy);
            if (dfs(x)) break;
            // No augmenting path: relax labels by the smallest slack so at
            // least one new edge becomes tight, then retry.
            LL d = KM_INF;
            for (int i = 0; i < ny; i++)
                if (!visy[i] && d > slack[i]) d = slack[i];
            for (int i = 0; i < nx; i++)
                if (visx[i]) lx[i] -= d;
            for (int i = 0; i < ny; i++) {
                if (visy[i]) ly[i] += d;
                else slack[i] -= d;
            }
        }
    }
    LL res = 0;
    for (int i = 0; i < ny; i++) res += G[linker[i]][i];
    return res;
}

// Reads one test case (n was already read by main): m, m1, m2, the n ship
// start vertices, then m1 undirected edges between non-port vertices and m2
// undirected edges between a port and a non-port vertex (1-based input).
void read()
{
    scanf("%d%d%d", &m, &m1, &m2);
    o = n + m;
    for (int i = 0; i < o; i++)
        for (int j = 0; j < o; j++) g[i][j] = INF;
    for (int i = 0; i < n; i++) scanf("%d", &a[i]), a[i]--, a[i] += n;
    int u, v, w;
    while (m1--) {
        scanf("%d%d%d", &u, &v, &w);
        u--, v--;
        u += n;
        v += n;
        g[u][v] = min(g[u][v], w);
        g[v][u] = min(g[v][u], w);
    }
    while (m2--) {
        scanf("%d%d%d", &u, &v, &w);
        u--, v--;
        v += n;  // u stays in [0, n): it is a port
        g[u][v] = min(g[u][v], w);
        g[v][u] = min(g[v][u], w);
    }
}

// Computes shortest paths, builds the assignment weights and prints the
// minimum total distance for the current test case.
void work()
{
    // Floyd with the intermediate vertex k as the OUTERMOST loop, restricted
    // to non-port vertices [n, o): a path may end at a port but never pass
    // through one. Skipping INF operands keeps the sum below 2*INF (no int
    // overflow) and leaves results unchanged.
    for (int k = n; k < o; k++)
        for (int i = 0; i < o; i++) {
            if (g[i][k] == INF) continue;
            for (int j = 0; j < o; j++)
                if (g[k][j] != INF) g[i][j] = min(g[i][j], g[i][k] + g[k][j]);
        }
    // Minimum-cost assignment == maximum-weight matching on negated costs.
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            G[i][j] = (g[a[i]][j] == INF) ? UNREACHABLE : -(LL)g[a[i]][j];
    nx = ny = n;
    // NOTE(review): if no all-reachable assignment exists, the printed total
    // includes the UNREACHABLE penalty; the problem presumably guarantees one.
    printf("%lld\n", -km());
}

int main()
{
    while (scanf("%d", &n) != EOF) {
        read();
        work();
    }
    return 0;
}