二分图最大匹配-HK算法的简单理解和实现

求一个二分图的最大匹配,朴实的匈牙利算法的复杂度为O (VE),优点是代码量很少,而且顶点数目在1000以内的时候表现不错。

bool dfs(int u,int tot){
	for(int i = head[u];i != -1; i = e[i].nxt){
		int v= e[i].v;
		if(vis[v] == tot) continue;
		vis[v] = tot;
		if(con[v] == -1 || dfs(con[v],tot)) {
			con[u] = v;con[v] = u;return true;
		}
	}return false;
}

但是当这个复杂度不足以解决问题的时候,比如说遇到这题的时候
HDU2389
就需要更优秀的算法来求最大匹配,于是就强制学习了一下HK算法

网上看了一些资料之后大多是描述算法步骤而证明比较简略,而在理解步骤之后发现HK算法可以看成是在匈牙利算法之上的优化,在原来的基础上加了一个bfs,dfs函数稍微改动一点。

问:这个bfs是干嘛的?
预先找到多条路径最短的增广路,然后匈牙利算法就沿着bfs找的路径去找增广路。

bfs具体步骤:
设二分图为X和Y两部分

  1. 把X中所有没匹配的点加入队列
  2. 每次出来一个点u,对于它连的Y中的每个点v,如果v没访问过且没匹配过,找到增广路,否则把v的匹配点压入队列

bfs的时候顺便记录一下每个点在bfs中的层次,看代码比较容易理解

bool bfs(){
	memset(dep,0,sizeof dep);//  初始化层次
	queue q;while(q.size()) q.pop();
	for(int i = 1;i <= n;++i) if(con[i] == -1) q.push(i);
	bool flag = false;//标记有没有找到增广路
	while(q.size()){
		int u = q.front();q.pop();
		for(int i = head[u];i != -1; i = e[i].nxt){
			int v = e[i].v;
			if(!dep[v]){//没在bfs中访问过,防止增广路相交
				dep[v] = dep[u] + 1;
				if(con[v] == -1) flag = true;//v没匹配过,找到增广路
				else dep[con[v]] = dep[v] + 1,q.push(con[v]);//v匹配过,把它的匹配点压入队列
			}
		}
	}
	return flag;
}

略加修改的dfs找增广路过程:

bool dfs(int u){
	for(int i = head[u];i != -1; i = e[i].nxt){
		int v= e[i].v;
		if(dep[v] != dep[u] + 1) continue;//保证v是u在bfs中的下一个点
		dep[v] = 0;
		if(con[v] == -1 || dfs(con[v])) {
			con[u] = v;con[v] = u;return true;
		}
	}return false;
}

练习:
HDU2389
AC代码:

#include
#include
#include
#include
#include
#define ll long long
using namespace std;
typedef pair P;
const int maxn = 6050;//存的点最多是n+m 
int vis[maxn];
struct node{
	int v,nxt;
}e[maxn*1500];
int cnt = 0;
int head[maxn];
int con[maxn];
ll x[maxn],y[maxn],s[maxn];
ll t;
void add(int u,int v){
	e[cnt].v = v;
	e[cnt].nxt = head[u];
	head[u] = cnt++;
}
int n,m,p;
void init(){
	cnt = 0;
	memset(head,-1,sizeof head);
	memset(vis,0,sizeof vis);
	memset(con,-1,sizeof con);
	cin>>t>>n;
	for(int i = 1;i <= n;++i) scanf("%lld%lld%lld",&x[i],&y[i],&s[i]);
	cin>>m;
	for(int i = 1;i <= m;++i) scanf("%lld%lld",&x[i+n],&y[i+n]);
	for(int i = 1;i <= n;++i)for(int j = 1;j <= m;++j){
		int u = i;
		int v = n+j;
		ll w = (x[u] - x[v])*(x[u] - x[v]) + (y[u] - y[v])*(y[u] - y[v]);
		ll temp = t*s[u];
		temp*=temp;
		if(temp >= w) add(u,v);
	}
}
int dep[maxn];
bool bfs(){
	memset(dep,0,sizeof dep);
	queue q;while(q.size()) q.pop();
	for(int i = 1;i <= n;++i) if(con[i] == -1) q.push(i);
	bool flag = false;
	while(q.size()){
		int u = q.front();q.pop();
		for(int i = head[u];i != -1; i = e[i].nxt){
			int v = e[i].v;
			if(!dep[v]){
				dep[v] = dep[u] + 1;
				if(con[v] == -1) flag = true;
				else dep[con[v]] = dep[v] + 1,q.push(con[v]);
			}
		}
	}
	return flag;
}
bool dfs(int u){
	for(int i = head[u];i != -1; i = e[i].nxt){
		int v= e[i].v;
		if(dep[v] != dep[u] + 1) continue;
		dep[v] = 0;
		if(con[v] == -1 || dfs(con[v])) {
			con[u] = v;con[v] = u;return true;
		}
	}return false;
}
int ca = 0;
void sol(){
	int ans = 0;
	while(bfs()){
		for(int i = 1;i <= n;++i) if(con[i] == -1 && dfs(i)) ans++;
	}
	printf("Scenario #%d:\n%d\n\n",++ca,ans);
}
int main(){
	int T;cin>>T;
	while(T--){
		init();sol();
	}
} 

你可能感兴趣的:(图论)