POJ 2400 Supervisor, Supervisee (KM + 回溯) - from lanshui_Yang

            题目大意:有 n 个老板 和 n 个员工 ,每个老板对每个员工都有一个满意度(范围 1 ~ n ),每个员工对每个老板也有一个满意度(范围1 ~ n ),但每个老板只能雇佣一个员工 , 每个员工也只能为一个老板工作,定义 :平均满意度 = ((每个人的满意度之和) - 2 * n )/ (2 * n) ,要求找出是平均满意度最小的分配方案,如果有多种方案,则按员工序号的字典序输出(即老板的编号始终是按 1 ~ n 输出,按每个老板对应员工的序号的字典序输出)。

        解题思路:这道题是找出所有的最小权值匹配 ,用到KM算法和回溯找全排列(注意剪枝!),具体请看程序。

        Ps:这道题的输入有问题,应先输入 第 二个 矩阵, 再输入 第 一个 矩阵 ,还有 找全排列的时候别忘剪枝!!

代码如下:

#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<vector>
#include<queue>
#define mem(a , b) memset(a , b , sizeof(a))
using namespace std ;
const int INF = 0x7fffffff ;
const int MAXN = 20 ;
int w[MAXN][MAXN] ;
int sp[MAXN][MAXN] , em[MAXN][MAXN] ;
int Lx[MAXN] , Ly[MAXN] ;
int link[MAXN] , linkx[MAXN] ;
int slack[MAXN] ;
bool visx[MAXN] , visy[MAXN] ;
bool vis[MAXN] ;
int flag ;
int tans ;
int sc ;
double ans ;
int ca ;
int n ;
void chu()
{
    mem(link , -1) ;
    mem(Lx , 0) ;
    mem(Ly , 0) ;
    mem(visx , 0) ;
    mem(visy , 0) ;
    sc = 0 ;
}
void init()
{
    scanf("%d" , &n) ;
    int i , j ;
    for(i = 1 ; i <= n ; i ++)
    {
        for(j = 1 ; j <= n ; j ++)
        {
            int tmp ;
            scanf("%d" , &tmp) ;
            em[i][tmp] = j ;  // 员工对老板的满意度
        }
    }
    for(i = 1 ; i <= n ; i ++)
    {
        for(j = 1 ; j <= n ; j ++)
        {
            int tmp ;
            scanf("%d" , &tmp) ;
            sp[i][tmp] = j ; // 老板对员工的满意度
        }
    }

    for(i = 1 ; i <= n ; i ++)  // 建图
    {
        for(j = 1 ; j <= n ; j ++)
        {
            w[i][j] = (sp[i][j] + em[j][i]) * (-1) ;
        }
    }
}
int dfs(int u)
{
    visx[u] = 1 ;
    int v ;
    for(v = 1 ; v <= n ; v ++)
    {
        if(visy[v])
            continue ;
        int t = Lx[u] + Ly[v] - w[u][v] ;
        if(t == 0)
        {
            visy[v] = 1 ;
            if(link[v] == -1 || dfs(link[v]))
            {
                link[v] = u ;
                linkx[u] = v ;
                return 1 ;
            }
        }
        else if(slack[v] > t)
        {
            slack[v] = t ;
        }
    }
    return 0 ;
}
void KM()
{
    int i , j ;
    int MAX = -INF ;
    for(i = 1 ; i <= n ; i ++)
    {
        for(j = 1 ; j <= n ; j ++)
        {
            if(w[i][j] > MAX)
                MAX = w[i][j] ;
        }
        Lx[i] = MAX ;
    }
    mem(Ly , 0) ;
    for(i = 1 ; i <= n ; i ++)
    {
        for(j = 1 ; j <= n ; j ++)
        {
            slack[j] = INF ;
        }
        while (1)
        {
            mem(visx , 0) ;
            mem(visy , 0) ;
            if(dfs(i))
                break ;
            int d = INF ;
            int k ;
            for(k = 1 ; k <= n ; k ++)
            {
                if(!visy[k] && d > slack[k])
                    d = slack[k] ;
            }
            for(k = 1 ; k <= n ; k ++)
            {
                if(visx[k])
                    Lx[k] -= d ;
                if(visy[k])
                    Ly[k] += d ;
                else
                {
                    slack[k] -= d ;
                }
            }
        }
    }
    tans= 0 ;
    for(i = 1 ; i <= n ; i ++)
    {
        tans += w[link[i]][i] ;
    }
    ans = (-1.0 * tans - 2 * n) / (2.0 * n) ;
}

void find(int cnt , int cost)         //全排列搜索找出所有答案
{
    if(cost < tans) return ;  // 此处剪枝很重要,不然会 TLE !!
    if(cnt > n)
    {
        if(tans != cost) return ;
        printf("Best Pairing %d\n", ++ sc);
        for(int i = 1 ; i <= n ; i ++)
            printf("Supervisor %d with Employee %d\n", i , linkx[i]);
    }
    else
    {
        for(int i = 1 ; i <= n ; i ++)
        {
            if(!vis[i])
            {
                vis[i]=1;
                linkx[cnt]=i;
                find(cnt + 1, cost + w[cnt][i]);
                vis[i]=0;
            }
        }
    }
}
void solve()
{
    KM() ;
    if(flag)
    {
        flag = 0 ;
    }
    else
        puts("") ;
    printf("Data Set %d, Best average difference: %.6f\n" , ++ ca , ans) ;
    mem(vis , 0) ;
    find(1 , 0) ;
}
int main()
{
    int T ;
    scanf("%d" , &T) ;
    ca = 0 ;
    flag = 1 ;
    while (T --)
    {
        chu() ;
        init() ;
        solve() ;
    }
    return 0 ;
}



你可能感兴趣的:(回溯,KM算法)