匈牙利算法板子(用邻接表写可以做到O(n*m)):
const int maxn = 510;
int cnt_x,cnt_y;
int G[maxn][maxn];
int link[maxn];
bool vis[maxn];
bool find(int u){
for(int v=1;v<=cnt_y;v++){
if(G[u][v]&&!vis[v]){
vis[v]=true;
if(link[v]==-1||find(link[v])){
link[v]=u;
return true;
}
}
}
return false;
}
int Hungary(){
int ans=0;
memset(link,-1,sizeof(link));
for(int u=1;u<=cnt_x;u++){
memset(vis,0,sizeof(vis));
if(find(u)) ans++;
}
return ans;
}
KM算法の O(n^3)的板子:
#include
#define ll long long
#define endl '\n'
#define mem(a) memset(a,0,sizeof(a))
#define IO ios::sync_with_stdio(false);cin.tie(0);
using namespace std;
const int INF=0x3f3f3f3f;
const int mod=1e9+7;
const int maxn=1e3+7;
int wx[maxn],wy[maxn];
int cx[maxn],cy[maxn];
int visx[maxn],visy[maxn];
int cntx,cnty;
int Map[maxn][maxn];
int slack[maxn];
bool find(int u){
visx[u]=1;
for(int v=1;v<=cnty;v++){
if(!visy[v]&&Map[u][v]!=INF){
int t=wx[u]+wy[v]-Map[u][v];
if(t==0){
visy[v]=1;
if(cy[v]==-1||find(cy[v])){
cx[u]=v;
cy[v]=u;
return true;
}
}
else if(t>0){
slack[v]=min(slack[v],t);
}
}
}
return false;
}
int KM(){
memset(cx,-1,sizeof(cx));
memset(cy,-1,sizeof(cy));
memset(wx,0,sizeof(wx));
memset(wy,0,sizeof(wy));
for(int i=1;i<=cntx;i++){
for(int j=1;j<=cnty;j++){
if(Map[i][j]==INF) continue;
wx[i]=max(wx[i],Map[i][j]);
}
}
for(int i=1;i<=cntx;i++){
memset(slack,INF,sizeof(slack));
while(1){
memset(visx,0,sizeof(visx));
memset(visy,0,sizeof(visy));
if(find(i)) break;
int d=INF;
for(int j=1;j<=cnty;j++){
if(!visy[j]&&d>slack[j]) d=slack[j];
}
for(int j=1;j<=cntx;j++){
if(visx[j]) wx[j]-=d;
}
for(int j=1;j<=cnty;j++){
if(visy[j]) wy[j]+=d;
else slack[j]-=d;
}
}
}
int ans=0;
for(int i=1;i<=cntx;i++){
if(cx[i]!=-1) ans+=Map[i][cx[i]];
}
return ans;
}
int main(){
while(cin>>cntx>>cnty){
for(int i=1;i<=cntx;i++){
for(int j=1;j<=cnty;j++){
cin>>Map[i][j];
}
}
cout<<KM()<<endl;
}
return 0;
}
O(n^4)的板子,看上去易懂,但是不是很推荐,适合入门:
ps(n左m右)
#include
#define ll long long
#define endl '\n'
#define mem(a) memset(a,0,sizeof(a))
#define IO ios::sync_with_stdio(false);cin.tie(0);
using namespace std;
const int INF=0x3f3f3f3f;
const int mod=1e9+7;
const int maxn=1e3+7;
int w[maxn][maxn];
int line[maxn],usex[maxn],usey[maxn],cx[maxn],cy[maxn];
int ans,n,m;
bool find(int x){
usex[x]=1;
for(int i=1;i<=m;i++){
if((usey[i]==0)&&(cx[x]+cy[i]==w[x][i])){
usey[i]=1;
if(line[i]==0||find(line[i])){
line[i]=x;
return true;
}
}
}
return false;
}
int KM(){
for(int i=1;i<=n;i++){
while(true){
int d=INF;
memset(usex,0,sizeof(usex));
memset(usey,0,sizeof(usey));
if(find(i)) break;
for(int j=1;j<=n;j++){
if(usex[j]){
for(int k=1;k<=m;k++){
if(!usey[k]) d=min(d,cx[j]+cy[k]-w[j][k]);
}
}
}
if(d==INF) return -1;
for(int j=1;j<=n;j++){
if(usex[j]) cx[j]-=d;
}
for(int j=1;j<=m;j++){
if(usey[j]) cy[j]+=d;
}
}
}
ans=0;
for(int i=1;i<=m;i++){
ans+=w[line[i]][i];
}
return ans;
}
int main(){
while(cin>>n>>m){
memset(cx,0,sizeof(cx));
memset(cy,0,sizeof(cy));
memset(w,0,sizeof(w));
memset(line,0,sizeof(line));
for(int i=1;i<=n;i++){
int d=0;
for(int j=1;j<=n;j++){
cin>>w[i][j];
d=max(d,w[i][j]);
}
cx[i]=d;
}
cout<<KM()<<endl;
}
return 0;
}
hopcroft-karp算法(O(sqrt(n)m) ):
#include
#define ll long long
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,n,a) for(int i=n;i>=a;i--)
#define endl '\n'
#define mem(a) memset(a,0,sizeof(a))
#define IO ios::sync_with_stdio(false);cin.tie(0);
using namespace std;
const int INF=0x3f3f3f3f;
const int mod=1e9+7;
const int maxn=505;
int bmap[maxn][maxn];
int cx[maxn],cy[maxn],dx[maxn],dy[maxn];
bool vis[maxn];
int nx,ny,k,dis;
bool searchpath(){
queue<int> q;
dis=INF;
memset(dx,-1,sizeof(dx));
memset(dy,-1,sizeof(dy));
for(int i=1;i<=nx;i++){
if(cx[i]==-1){
q.push(i);
dx[i]=0;
}
}
while(!q.empty()){
int u=q.front();
q.pop();
if(dx[u]>dis) break;
for(int v=1;v<=ny;v++){
if(bmap[u][v]&&dy[v]==-1){
dy[v]=dx[u]+1;
if(cy[v]==-1) dis=dy[v];
else{
dx[cy[v]]=dy[v]+1;
q.push(cy[v]);
}
}
}
}
return dis!=INF;
}
int find(int u){
for(int v=1;v<=ny;v++){
if(!vis[v]&&bmap[u][v]&&dy[v]==dx[u]+1){
vis[v]=1;
if(cy[v]!=-1&&dy[v]==dis) continue;
if(cy[v]==-1||find(cy[v])){
cy[v]=u;cx[u]=v;
return true;
}
}
}
return false;
}
int MaxMatch(){
int res=0;
memset(cx,-1,sizeof(cx));
memset(cy,-1,sizeof(cy));
while(searchpath()){
memset(vis,0,sizeof(vis));
for(int i=1;i<=nx;i++){
if(cx[i]==-1){
res+=find(i);
}
}
}
return res;
}
int main(){
while(cin>>k>>nx>>ny&&k){
memset(bmap,0,sizeof(bmap));
for(int i=1;i<=k;i++){
int u,v;cin>>u>>v;
bmap[u][v]=1;
}
cout<<MaxMatch()<<endl;
}
}