KM bfs写法

KM bfs写法

2018astar资格赛的第三题整数规划。

\(x, y\)看成二分图两边的顶标,\(a_{ij}\)就是二分图的边权,整道题其实就是求二分图的最大权匹配。
然后打了个\(dfs\)\(KM\)\(TLE\)了,后来听别人说要用\(bfs\)的写法,因为那个才是真正的\(O(n^3)\)\(dfs\)的写法最坏情况还是\(O(n^4)\)

原理是一样的,只不过\(bfs\)有一点点像迭代,每一次也只是搜\(diff=0\)的情况,而且右边的点只会搜索一次(或者说是左边的点只会搜索一次,即左边的每个点只会进队一次),用\(pre\)记住当前的交错路径,找到未匹配的就可以沿交错路径进行修改。

#include 
using namespace std;

typedef long long LL;
const int maxn=210;
const LL inf=1LL<<60;

int n;

namespace KM
{
    int n;
    LL mat[maxn][maxn];               //边权
    int matcha[maxn], matchb[maxn];   //左边的点匹配的右边点;右边的点匹配的左边点
    LL marka[maxn], markb[maxn];      //左顶标;右顶标
    LL slack[maxn];                   //松弛数组
    bool visa[maxn], visb[maxn];      //访问标记

    int head, tail;
    int q[maxn], pre[maxn];           //队列;交错路径

    bool check(int cur)
    {
        visb[cur]=true;     //标记cur已搜索
        if (matchb[cur])    //已匹配,即当前匹配失败
        {
            if (!visa[matchb[cur]])    //匹配的点是否已进队
            {
                q[++tail]=matchb[cur];
                visa[matchb[cur]]=true;
            }
            return false;
        }
        //未匹配,即当前匹配成功,沿交错路径进行匹配
        while (cur)
            swap(cur, matcha[matchb[cur]=pre[cur]]);
        return true;
    }

    void bfs(int start)
    {
        fill(visa, visa+1+n, false);
        fill(visb, visb+1+n, false);
        fill(slack, slack+1+n, inf);

        head=tail=1;
        q[1]=start;
        visa[start]=true;

        while (1)
        {
            while (head<=tail)
            {
                int cur=q[head++];
                for (int i=1; i<=n; ++i)
                {
                    LL diff=marka[cur]+markb[i]-mat[cur][i];
                    if (!visb[i] && diff<=slack[i])   //visb=true说明已搜索,无需更新slack和pre,也是保证pre的正确性
                    {
                        slack[i]=diff;
                        pre[i]=cur;
                        if (diff==0)  //diff=0,可以尝试匹配
                            if (check(i)) return; //匹配成功可直接返回
                    }
                }
            }

            LL delta=inf;
            for (int i=1; i<=n; ++i)
                if (!visb[i] && slack[i]) delta=min(slack[i], delta);
            for (int i=1; i<=n; ++i)    //松弛
            {
                if (visa[i]) marka[i]-=delta;
                if (visb[i]) markb[i]+=delta;
                else slack[i]-=delta;   //维护slack的正确性(参考diff的计算及marka,markb的变化)
            }

            head=1, tail=0;
            for (int i=1; i<=n; ++i)
                if (!visb[i] && !slack[i] && check(i)) return;
                //松弛后尝试匹配diff=0的点。
        }
    }

    void solve()
    {
        fill(matcha, matcha+1+n, 0);
        fill(matchb, matchb+1+n, 0);
        fill(markb, markb+1+n, 0);

        for (int i=1; i<=n; ++i)
        {
            marka[i]=0;
            for (int j=1; j<=n; ++j)
                marka[i]=max(marka[i], mat[i][j]);
        }

        for (int i=1; i<=n; ++i) bfs(i);
    }
}

void read()
{
    scanf("%d", &n);
    KM::n=n;
    for (int i=1; i<=n; ++i)
        for (int j=1; j<=n; ++j)
        {
            int x;
            scanf("%d", &x);
            KM::mat[i][j]=-x;
        }
}

void solve()
{
    KM::solve();
    LL ans=0;
    for (int i=1; i<=n; ++i)
        ans+=KM::marka[i]+KM::markb[i];
    printf("%lld\n", -ans);
}

int main()
{
    int casesum;
    scanf("%d", &casesum);
    for (int i=1; i<=casesum; ++i)
    {
        printf("Case #%d: ", i);
        read();
        solve();
    }
    return 0;
}

转载于:https://www.cnblogs.com/GerynOhenz/p/9458006.html

你可能感兴趣的:(KM bfs写法)