匈牙利算法是解决二分匹配的一个经典算法,昨天学长很详细的讲了一下。也算小有理解,在这里分享一下。
匈牙利算法就是解决二分最优匹配的算法,比如给出hdoj上面这道题为例:http://acm.hdu.edu.cn/showproblem.php?pid=1083
给出n个学生和p个课程。每个课程有n个同学中的几个在学习,要求我们求出能不能让选择出n个同学组成一个委员会,且能够每门课程都有一名代表在该委员会中,这就是一个最优的二分匹配问题。首先是我们怎么建图。拿给出的第一组测试数据来说吧。我们用一个map数组来存这个图,首先初始化为0,第p个课程第n个同学在学的话我们map[p][n]=1。就会得到这样一个图
课程/同学 1 2 3
1 1 1 1
2 1 1 0
3 1 0 0
我们首先第一门课程让第一个同学代表,即link[1]=1,然后第二门课程发现也可以让第一个同学代表,我们找第一门课程可不可以连其他同学来让出第一个同学给第二门课程,发现第一门课程还可以让第二个代表,这样第一门课程让第二个同学代表,即link【1】=2,这样第二们课程就可以让第一个同学代表,即link【2】=1,继续往下走发现第三门课程只能又第一个同学代表,我们看第一个同学在代表第二门课程,我们看第二门课程可不可以让其他同学代表让出第一个同学,发现第二门课程还可以让第二个同学代表,但是第二个同学也是第一门课程占着,我们在找可不可以把第二个同学让给第二门课程。发现可以让第一门课程由第三个同学代理。link【1】=3,link【2】=2,link【3】=1.这样就能够完美的二分匹配,即每门课程都有一名代表在该委员会中,其实就是一个不断让和不断找的过程,实现是用一个递归+深搜实现的,现在附上这道题代码:
/*hdoj1083二分图匹配模板题*/ #include <iostream> #include <cstring> using namespace std; int n,p; int link[305],vis[305],map[305][305]; bool dfs(int x) { for(int i=1;i<=n;i++) { if(map[x][i]==1 && vis[i]==0) { vis[i]=1; if(link[i]==0 || dfs(link[i])) { link[i]=x; return true; } //vis[i]=0; //搞不懂加上这个会超时 } } return false; } int main() { int T,h,m,count,i; cin>>T; while(T--) { cin>>p>>n; memset(map,0,sizeof(map)); for(int i=1;i<=p;i++) { cin>>m; for(int j=1;j<=m;j++) { cin>>h; map[i][h]=1; } } memset(link,0,sizeof(link)); count=0; for(int i=1;i<=p;i++) { memset(vis,0,sizeof(vis)); if(dfs(i)) count++; } if(count==p) cout<<"YES"<<endl; else cout<<"NO"<<endl; } return 0; }
这里还有nyoj上的一道:http://acm.nyist.net/JudgeOnline/problem.php?pid=239
如果还是用上面的方法的话就会超时,这里用邻接表来实现的话会快一点,贴个代码:
#include <stdio.h> #include <vector> #include <cstring> using namespace std; vector<int> v[505]; int vis[505],link[505]; bool getnum(int i) { for(int j=0;j<v[i].size();j++) { if(vis[v[i][j]]==0) //注意这里容器的访问,ij表示v[i]中的第j个元素是 { vis[v[i][j]]=1; if(link[v[i][j]]==0 || getnum(link[v[i][j]])) { link[v[i][j]]=i; return true; } //vis[v[i][j]]=0;//这里不要这个因为这个会重复判断 } } return false; } int main() { int T,n,k,a,b,count; scanf("%d",&T); while(T--) { scanf("%d%d",&n,&k); memset(v,0,sizeof(v)); memset(link,0,sizeof(link)); for(int i=0;i<k;i++) { scanf("%d%d",&a,&b); v[a].push_back(b); } count=0; for(int i=1;i<=n;i++) { memset(vis,0,sizeof(vis));//清零 if(getnum(i)) count++; } printf("%d\n",count); } }