二分图的最佳完美匹配--KM算法(DFS寻路+BFS寻路(O(n^3))) + HDU2255入门题

Reference Blog:
原理清晰深刻:https://blog.csdn.net/sixdaycoder/article/details/47720471
较易于理解:https://www.cnblogs.com/wenruo/p/5264235.html

如果二分图的每条边都有一个权(可以是负数),要求一种完备匹配方案,使得所有匹配边的权和最大,记作最佳完美匹配。(特别地,当所有边的权都为1时,问题就退化为求二分图的最大匹配)

算法流程:
[图:KM算法流程(原文为图片,此处未能保留)]

dfs寻增广路模板:(只对随机数据接近O(n^3);对于极限数据(如w[i][j]很大),slack优化作用不显著,最坏情况下仍为O(n^4))

// Shared state for the DFS-based KM template. Vertices on both sides are
// numbered 1..n, so every array is sized a little above the largest n.
const int AX = 300 + 6;
int n;                     // number of vertices on each side
int w[AX][AX];             // w[x][y]: weight of edge (left x, right y)
int lx[AX], ly[AX];        // vertex labels for the left / right side
int linker[AX];            // linker[y]: left vertex matched to right y, -1 when free
int slack[AX];             // slack[y]: smallest lx[x] + ly[y] - w[x][y] seen in the current search
bool visx[AX], visy[AX];   // visited flags of the current search (left / right)
// Looks for an augmenting path that starts at left vertex x and uses only
// tight edges (lx[x] + ly[y] == w[x][y]). Every vertex touched is marked in
// visx / visy. Every other right vertex (already visited, or joined by a
// non-tight edge) gets slack[y] lowered to lx[x] + ly[y] - w[x][y] if that
// is smaller. On success the path is flipped into linker and true is returned.
bool dfs( int x ){
    visx[x] = true;
    for( int y = 1 ; y <= n ; ++y ){
        const int gap = lx[x] + ly[y] - w[x][y];
        if( visy[y] || gap != 0 ){
            // Not usable for augmenting right now: only remember the smallest gap.
            if( gap < slack[y] ) slack[y] = gap;
            continue;
        }
        visy[y] = true;
        if( linker[y] == -1 || dfs( linker[y] ) ){
            linker[y] = x;
            return true;
        }
    }
    return false;
}

// Kuhn-Munkres with a slack array (DFS augmenting). Expects n and
// w[1..n][1..n] to be filled; INF comes from the surrounding program.
// Afterwards linker[y] is the left vertex matched to right vertex y.
void KM(){
    memset( linker , -1 , sizeof(linker) );
    memset( ly , 0 , sizeof(ly) );
    // Each left label starts at its heaviest incident edge.
    for( int x = 1 ; x <= n ; x++ ){
        int best = -INF;
        for( int y = 1 ; y <= n ; y++ )
            if( w[x][y] > best ) best = w[x][y];
        lx[x] = best;
    }
    for( int root = 1 ; root <= n ; root++ ){
        // slack is rebuilt from scratch for every new left vertex
        for( int y = 1 ; y <= n ; y++ ) slack[y] = INF;
        for( ;; ){
            memset( visx , false , sizeof(visx) );
            memset( visy , false , sizeof(visy) );
            if( dfs( root ) ) break;
            // No augmenting path: shift labels by the smallest slack of an
            // unvisited right vertex.
            int delta = INF;
            for( int y = 1 ; y <= n ; y++ )
                if( !visy[y] && slack[y] < delta ) delta = slack[y];
            for( int v = 1 ; v <= n ; v++ ){
                if( visx[v] ) lx[v] -= delta;
                if( visy[v] ) ly[v] += delta;
                else slack[v] -= delta;
            }
        }
    }
}

BFS寻路模板(真正O(n^3))

// BFS-based Kuhn-Munkres template (no recursion). Weights, labels, slack
// and the label shift are all 64-bit so large edge weights work.
const int AX = 3e2+6;
long long w[AX][AX];          // w[x][y]: weight of edge (left x, right y), 1-indexed
long long lx[AX] , ly[AX];    // vertex labels (left / right)
int linker[AX];               // linker[y]: left vertex matched to right y, -1 if free; linker[0] is scratch
long long slack[AX];          // slack[y]: min lx[x] + ly[y] - w[x][y] over left vertices already in the tree
int n ;
bool visy[AX];                // right vertices already in the tree (index 0 = virtual start)
int pre[AX];                  // pre[y]: previous right vertex on the alternating path to y

// Adds left vertex k to the matching. Right index 0 is used as a virtual
// vertex matched to k, so the tree grows from y = 0. Each round pulls in the
// unvisited right vertex with the smallest slack and shifts labels by that
// slack (it may be negative, since KM() starts all labels at 0). When a free
// right vertex is reached, the path stored in pre[] is flipped into linker.
// The caller must clear visy before each call.
void bfs( int k ){
    // Must exceed every real slack; the int-sized INF (0x3f3f3f3f) is too
    // small for 64-bit weights.
    const long long kInf = 0x3f3f3f3f3f3f3f3fLL;
    int x , y = 0 , yy = 0 ;
    long long delta;   // 64-bit: an int here truncated the 64-bit slacks
    memset( pre , 0 , sizeof(pre) );
    for( int i = 1 ; i <= n ; i++ ) slack[i] = kInf;
    linker[y] = k;
    while(1){
        x = linker[y]; delta = kInf; visy[y] = true;
        for( int i = 1 ; i <= n ;i++ ){
            if( !visy[i] ){
                if( slack[i] > lx[x] + ly[i] - w[x][i] ){
                    slack[i] = lx[x] + ly[i] - w[x][i];
                    pre[i] = y;
                }
                if( slack[i] < delta ) delta = slack[i] , yy = i ;
            }
        }
        // Shift labels: tree vertices absorb delta, outside vertices' slack shrinks.
        for( int i = 0 ; i <= n ; i++ ){
            if( visy[i] ) lx[linker[i]] -= delta , ly[i] += delta;
            else slack[i] -= delta;
        }
        y = yy ;
        if( linker[y] == -1 ) break;
    }
    // Walk back along pre[] and shift every match one step along the path.
    while( y ) linker[y] = linker[pre[y]] , y = pre[y];
}

// Builds the maximum-weight perfect matching for w[1..n][1..n].
// Afterwards linker[y] is the left vertex matched to right vertex y.
void KM(){
    memset( lx , 0 ,sizeof(lx) );
    memset( ly , 0 ,sizeof(ly) );
    memset( linker , -1, sizeof(linker) );
    for( int i = 1 ; i <= n ; i++ ){
        memset( visy , false , sizeof(visy) );
        bfs(i);
    }
}

HDU2255
AC Code:
O(n^3)(DFS + slack 优化:随机数据下近似O(n^3),最坏情况O(n^4))

// HDU2255: maximum-weight perfect matching with KM (DFS augmenting + slack).
#include <cstdio>
#include <cstring>
// Larger-stack request understood by MSVC-style toolchains; other compilers ignore it.
#pragma comment(linker, "/STACK:1024000000,1024000000")
#define INF 0x3f3f3f3f
#define LL long long
using namespace std;
const int AX = 3e2+6;
bool visx[AX];
bool visy[AX];
int w[AX][AX];
int lx[AX] , ly[AX];  // feasible vertex labels
int linker[AX];       // linker[y]: left vertex matched to right y, -1 if unmatched
int slack[AX];        // slack[j]: min of lx[i] + ly[j] - w[i][j] over the i connected to j in this search
int n ;

// Searches for an augmenting path from left vertex x over tight edges
// (lx[x] + ly[y] == w[x][y]), marking visited vertices in visx / visy and
// lowering slack[y] for the other right vertices. Flips the path into
// linker and returns true on success.
bool dfs( int x ){
    visx[x] = true;
    for( int y = 1 ; y <= n ; y ++ ){
        if( !visy[y] && lx[x] + ly[y] == w[x][y] ){
            visy[y] = true;
            if( linker[y] == -1 || dfs( linker[y] ) ){
                linker[y] = x ;
                return true;
            }
        }else if( slack[y] > lx[x] + ly[y] - w[x][y] ){
            // (x, y) is not in the equality subgraph, or y is already on the path
            slack[y] = lx[x] + ly[y] - w[x][y];
        }
    }
    return false;
}

// Runs KM on w[1..n][1..n]; afterwards linker[y] is the left partner of right y.
void KM(){
    memset( linker , -1 , sizeof(linker) );
    memset( ly , 0 , sizeof(ly) );
    // Each left label starts at the heaviest edge leaving that vertex.
    for( int i = 1 ; i <= n ; i++ ){
        lx[i] = -INF;
        for( int j = 1 ; j <= n ; j++ ){
            if( lx[i] < w[i][j] ) lx[i] = w[i][j];
        }
    }
    for( int x = 1 ; x <= n ; x++ ){
        for( int i = 1 ; i <= n ; i++ ) slack[i] = INF;  // slack is reset for every new x
        while(1){
            memset( visx , false , sizeof(visx) );
            memset( visy , false , sizeof(visy) );
            if( dfs(x) ){
                break;
            }else{
                // Matching failed, so x is in the alternating tree: take the
                // smallest slack among right vertices outside the tree.
                int delta = INF;
                for( int j = 1 ; j <= n ; j++ ){
                    if( !visy[j] && delta > slack[j] ){
                        delta = slack[j];
                    }
                }

                for( int i = 1 ; i <= n ; i++ ){
                    if( visx[i] ) lx[i] -= delta;
                }
                for( int i = 1 ; i <= n ; i++ ){
                    if( visy[i] ) ly[i] += delta;
                    else slack[i] -= delta;
                    // Labels of tree vertices on the left dropped by delta, so
                    // slack[j] = min(lx[i] + ly[j] - w[i][j]) of every right
                    // vertex outside the tree drops by delta too.
                }
            }
        }
    }
}

// Reads cases until input ends: n, then an n x n weight matrix (row = left
// vertex, column = right vertex). Prints the total weight of the matching
// found by KM for each case.
int main(){
    int value ;   // scratch for each matrix entry
    while( ~scanf("%d",&n) ){
        for( int row = 1 ; row <= n ; ++row ){
            for( int col = 1 ; col <= n ; ++col ){
                scanf("%d",&value);
                w[row][col] = value ;
            }
        }
        KM();
        int total = 0 ;
        for( int y = 1 ; y <= n ; ++y ){
            if( linker[y] == -1 ) continue;
            total += w[linker[y]][y];
        }
        printf("%d\n",total);
    }
    return 0 ;
}

TLE Code:
O(n^4)(不维护slack数组,每次修改顶标都要O(n^2)地枚举所有(已访问x, 未访问y)重新计算delta)

// HDU2255 with the naive KM: no slack array, so delta is recomputed in
// O(n^2) after every failed search, giving O(n^4) overall.
#include <cstdio>
#include <cstring>
#include <algorithm>
// Larger-stack request understood by MSVC-style toolchains; other compilers ignore it.
#pragma comment(linker, "/STACK:1024000000,1024000000")
#define INF 0x3f3f3f3f
#define LL long long
using namespace std;
const int AX = 3e2+6;
bool visx[AX];
bool visy[AX];
int w[AX][AX];
int lx[AX] , ly[AX];  // vertex labels (left / right)
int linker[AX];       // linker[y]: left vertex matched to right y, -1 if unmatched
int n ;

// Searches for an augmenting path from left vertex x over tight edges
// (lx[x] + ly[y] == w[x][y]), marking visited vertices in visx / visy.
// Flips the path into linker and returns true on success.
bool dfs( int x ){
    visx[x] = true;
    for( int y = 1 ; y <= n ; y ++ ){
        if( !visy[y] && lx[x] + ly[y] == w[x][y] ){
            visy[y] = true;
            if( linker[y] == -1 || dfs( linker[y] ) ){
                linker[y] = x ;
                return true;
            }
        }
    }
    return false;
}

// Runs KM on w[1..n][1..n]; afterwards linker[y] is the left partner of right y.
void KM(){
    memset( linker , -1 , sizeof(linker) );
    memset( ly , 0 , sizeof(ly) );
    // Each left label starts at the heaviest edge leaving that vertex.
    for( int i = 1 ; i <= n ; i++ ){
        lx[i] = -INF;
        for( int j = 1 ; j <= n ; j++ ){
            if( lx[i] < w[i][j] ) lx[i] = w[i][j];
        }
    }
    for( int x = 1 ; x <= n ; x++ ){
        while(1){
            memset( visx , false , sizeof(visx) );
            memset( visy , false , sizeof(visy) );
            if( dfs(x) ){
                break;
            }else{
                // delta must start at INF on every round. The previous global
                // delta was never reset, stayed at 0, and the labels never moved.
                int delta = INF;
                for( int i = 1 ; i <= n ; i++ ){
                    if( !visx[i] ) continue;
                    for( int j = 1 ; j <= n ; j++ ){
                        if( !visy[j] ){
                            // the gap of edge (i, j) uses i's own label, not lx[x]
                            delta = min( delta , lx[i] + ly[j] - w[i][j] );
                        }
                    }
                }
                for( int i = 1 ; i <= n ; i++ ){
                    if( visx[i] ) lx[i] -= delta;
                }
                for( int i = 1 ; i <= n ; i++ ){
                    if( visy[i] ) ly[i] += delta;
                }
            }
        }
    }
}

// Processes cases until input ends: read n and the n x n weight matrix,
// run KM, and print the weight of the resulting matching.
int main(){
    int entry ;   // scratch for each matrix entry
    while( ~scanf("%d",&n) ){
        for( int i = 1 ; i <= n ; i++ )
            for( int j = 1 ; j <= n ; j++ ){
                scanf("%d",&entry);
                w[i][j] = entry ;
            }
        KM();
        int sum = 0 ;
        for( int y = 1 ; y <= n ; y++ )
            sum += ( linker[y] != -1 ) ? w[linker[y]][y] : 0;
        printf("%d\n",sum);
    }
    return 0 ;
}

你可能感兴趣的:(二分图)