前言:
高中时候老师讲这个就听得迷迷糊糊,有一晚花了通宵看KM的Pascal代码,大概知道过程了,后来老师说不是重点,所以忘的差不多了。都知道二分图匹配是个难点,我这周花了些时间研究了一下这两个算法,总结一下
M代表匹配集合
未盖点:不与任何一条属于M的边相连的点
交错轨:属于M的边与不属于M的边交替出现的轨(链)
可增广轨:两端点是未盖点的交错轨
判断M是最大匹配的标准: M中不存在可增广轨
时间复杂度:O(|V||E|)
原理:
寻找M的可增广轨P,P包含2k+1条边,其中k条属于M,k+1条不属于M。修改M为M&P。即 这条轨进行与M进行对称差分运算。
所谓对称差分运算,就是比如X和Y都是集合,X & Y=( X 并 Y ) - ( x 交 Y )
有一个定理是:M & P的边数是 |M|+1,因此对称差分运算扩大了M
实现:
关于这个实现,有DFS和BFS两种方法。先列出DFS的代码,带注释。这段代码来自中山大学的教材
核心部分在dfs(x),来寻找可增广轨。如果找到的话,在Hungarian()中,最大匹配数加一。这是用了刚才提到的定理。大家可以想想初始状态是什么,又是如何变化的
#include <iostream> #define F(i,a,b) for (int i=a;i<=b;i++) #define maxn 1002 using namespace std; bool mk[maxn], map[maxn][maxn]; int match[maxn], n, m; bool dfs(int x) //寻找增广链,true表示找到 { for (int i=1;i<=m;i++) { if ( map[x][i] && !mk[i] ) { mk[i]=true; int t=match[i]; match[i]=x; if (t==0 || dfs(t)) return true; match[i]=t; } } return false; } int hungarian() { int max_=0; F(i,1,n) { memset(mk, false, sizeof(mk)); if ( dfs(i) ) max_++; } return max_; } int main() { freopen ("in.txt","r",stdin); memset(map, false, sizeof(map)); cin >> n >> m; int a, b; while (cin >> a >> b) { map[a][b]=true; } cout << hungarian() << endl; F(i,1,m) if (match[i]!=0) cout << match[i] << " " << i << endl; return 0; } /* http://blog.csdn.net/akof1314/archive/2009/08/07/4421262.aspxiageProblem.png in.txt 5 5 1 1 1 2 2 2 2 3 3 2 3 5 4 3 5 3 5 4 5 5 out.txt 5 剩下是匹配结果 */
第二种方法BFS,来自我的学长cnhawk
核心步骤还是寻找可增广链,过程是:
1.从左的一个未匹配点开始,把所有她相连的点加入队列
2.如果在右边找到一个未匹配点,则找到可增广链
3.如果在右边找到的是一个匹配的点,则看它是从左边哪个点匹配而来的,将那个点出发的所有右边点加入队列
这么说还是不容易明白,看代码吧
//匈牙利算法实现 int Bipartite(bool vc [][MAX],int nv1,int nv2) { //vc[][]为二分图,nv1,nv2分别为左边的点数 int i, j, x, n; //n为最大匹配数 int q[MAX], prev[MAX], qs, qe; //q是BFS用的队列,prev是用来记录交错链的,同时也用来记录右边的点是否被找过 int vm1[MAX], vm2[MAX]; //vm1,vm2分别表示两边的点与另一边的哪个点相匹配 n = 0; for( i = 0; i < nv1; i++ ) vm1[i] = -1; for( i = 0; i < nv2; i++ ) vm2[i] = -1; //初始化所有点为未被匹配的状态 for( i = 0; i < nv1; i++ ) { if(vm1[i] != -1)continue; //对于左边每一个未被匹配的点进行一次BFS找交错链 for( j = 0; j < nv2; j++ ) prev[j] = -2;//表示刚进行过初始化 //每次BFS时初始化右边的点 qs = qe = 0; //初始化BFS的队列 //下面这部分代码从初始的那个点开始,先把它能找的的右边的点放入队列 //★稍微改一下可以适用于用邻接表实现的二分图 for( j = 0; j < nv2; j++ ) if( vc[i][j] ) { prev[j] = -1; q[qe++] = j; } //BFS while( qs < qe ) { x = q[qs]; if( vm2[x] == -1 ) break; //如果找到一个未被匹配的点,则结束,找到了一条交错链 qs++; //下面这部分是扩展结点的代码,★稍微改一下可以适用于用邻接表实现的二分图 for( j = 0; j < nv2; j++ ) if( prev[j] == -2 && vc[vm2[x]][j] ) { //如果该右边点是一个已经被匹配的点,则vm2[x]是与该点相匹配的左边点 //从该左边点出发,寻找其他可以找到的右边点 prev[j] = x; q[qe++] = j; } } if( qs == qe ) continue; //没有找到交错链 //更改交错链上匹配状态 while( prev[x] > -1 ) { vm1[vm2[prev[x]]] = x; vm2[x] = vm2[prev[x]]; x = prev[x]; } vm2[x] = i; vm1[i] = x; //匹配的边数加一 n++; } return n; }
加权图中,权值最大的最大匹配
KM算法:
概念:
f(v) 是每个点的一个值,使得对任意u,v C V,f(u)+f(v)>=w[ eu,v ]
集合H:一个边集,使得H中所有u,v满足f(u)+f(v)=w[ eu,v ]
等价子图:Gf(V, H),标有f函数的G图
理论:
对于f和Gf,如果有一个理想匹配集合Mp,则Mp最优。
对于任意匹配集合M,weight( M )<weight( Mp )
KM算法的实质是扩展Gf,直到找到理想的匹配集合
伪代码
// S是未匹配的顶点集 while (M 不是 Mp) { //F(S)是Gf中与S中顶点相邻的顶点集 if( F(S)==T ) { d = min( f[u]+f[w]-weight[u][w] ); //u in S, w not in T for each v in V { if ( v in S ) f[v]=f[v]-d; else if ( v in T ) f[v]=f[v]-d; } } else // { w = F(S)-T中一个顶点 if ( w未匹配 ) { P是刚找到的增大路径 M = M与P的对称差分运算结果 S是某个未匹配的顶点 T= null } else { S=S+ {M中w的相邻点} T=T+w } } }
最后给一个代码,跟伪代码的思路不是很一样。从网上找的
#include <cstdio> #include <memory.h> #include <algorithm> // 使用其中的 min 函数 using namespace std; const int MAX = 1024; int n; // X 的大小 int weight [MAX] [MAX]; // X 到 Y 的映射(权重) int lx [MAX], ly [MAX]; // 标号 bool sx [MAX], sy [MAX]; // 是否被搜索过 int match [MAX]; // Y(i) 与 X(match [i]) 匹配 // 初始化权重 void init (int size); // 从 X(u) 寻找增广道路,找到则返回 true bool path (int u); // 参数 maxsum 为 true ,返回最大权匹配,否则最小权匹配 int bestmatch (bool maxsum = true); void init (int size) { // 根据实际情况,添加代码以初始化 n = size; for (int i = 0; i < n; i ++) for (int j = 0; j < n; j ++) scanf ("%d", &weight [i] [j]); } bool path (int u) { sx [u] = true; for (int v = 0; v < n; v ++) if (!sy [v] && lx[u] + ly [v] == weight [u] [v]) { sy [v] = true; if (match [v] == -1 || path (match [v])) { match [v] = u; return true; } } return false; } int bestmatch (bool maxsum) { int i, j; if (!maxsum) { for (i = 0; i < n; i ++) for (j = 0; j < n; j ++) weight [i] [j] = -weight [i] [j]; } // 初始化标号 for (i = 0; i < n; i ++) { lx [i] = -0x1FFFFFFF; ly [i] = 0; for (j = 0; j < n; j ++) if (lx [i] < weight [i] [j]) lx [i] = weight [i] [j]; } memset (match, -1, sizeof (match)); for (int u = 0; u < n; u ++) while (1) { memset (sx, 0, sizeof (sx)); memset (sy, 0, sizeof (sy)); if (path (u)) break; // 修改标号 int dx = 0x7FFFFFFF; for (i = 0; i < n; i ++) if (sx [i]) for (j = 0; j < n; j ++) if(!sy [j]) dx = min (lx[i] + ly [j] - weight [i] [j], dx); for (i = 0; i < n; i ++) { if (sx [i]) lx [i] -= dx; if (sy [i]) ly [i] += dx; } } int sum = 0; for (i = 0; i < n; i ++) sum += weight [match [i]] [i]; if (!maxsum) { sum = -sum; for (i = 0; i < n; i ++) for (j = 0; j < n; j ++) weight [i] [j] = -weight [i] [j]; // 如果需要保持 weight [ ] [ ] 原来的值,这里需要将其还原 } return sum; } int main() { freopen ("in.txt", "r", stdin); int n; scanf ("%d", &n); init (n); int cost = bestmatch (true); printf ("%d /n", cost); for (int i = 0; i < n; i ++) { printf ("Y %d -> X %d /n", i, match [i]); } return 0; } /* 5 3 4 6 4 9 6 4 5 3 8 7 5 3 4 2 6 3 2 2 5 8 4 5 4 7 //执行bestmatch (true) ,结果为 29 */ /* 5 7 6 4 6 1 4 6 5 7 2 3 5 7 6 8 4 7 8 8 5 2 6 5 6 3 //执行 bestmatch (false) ,结果为 21 */