uoj #80. 二分图最大权匹配 KM算法

题意

求二分图最大权匹配。
n<=400,m<=160000,val<=10^9

分析

妈妈我终于学会了KM算法系列。
一开始去网上找了个标,改了改交上去结果T了。。。然后就自己yy了一个模板出来。
具体的算法实现可以去看15年的论文。

代码

#include
#include
#include
#include
#include
using namespace std;

typedef long long LL;

const int N=405;
const LL inf=(LL)1e16;

int n,m,ans[N],match[N],k;
bool vx[N],vy[N];
LL lx[N],ly[N],map[N][N],slack[N];

int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

bool aug(int x)
{
    vx[x]=1;
    for (int i=1;i<=n;i++)
    {
        if (vy[i]||lx[x]+ly[i]-map[x][i]) continue;
        vy[i]=1;
        if (!match[i]||aug(match[i]))
        {
            match[i]=x;
            ans[x]=i;
            return 1;
        }
    }
    return 0;
}

LL km()
{
    for (int i=1;i<=n;i++)
    {
        memset(vx,0,sizeof(vx));
        memset(vy,0,sizeof(vy));
        if (aug(i)) continue;
        for (int j=1;j<=n;j++) slack[j]=inf;
        for (int x=1;x<=n;x++)
            if (vx[x])
                for (int y=1;y<=n;y++)
                    if (!vy[y]) slack[y]=min(slack[y],lx[x]+ly[y]-map[x][y]);
        while (1)
        {
            LL mn=inf;int to,s;
            for (int j=1;j<=n;j++) if (!vy[j]) mn=min(mn,slack[j]);
            for (int j=1;j<=n;j++)
            {
                if (vx[j]) lx[j]-=mn;
                if (vy[j]) ly[j]+=mn;
                else slack[j]-=mn,to=!slack[j]?j:to;
            }
            if (!match[to]) break;
            s=match[to];vy[to]=vx[s]=1;
            for (int j=1;j<=n;j++) if (!vy[j]) slack[j]=min(slack[j],lx[s]+ly[j]-map[s][j]);
        }
        memset(vx,0,sizeof(vx));
        memset(vy,0,sizeof(vy));
        aug(i);
    }
    LL ans=0;
    for (int i=1;i<=n;i++) ans+=lx[i]+ly[i];
    return ans;
}

int main()
{
    n=read();m=read();k=read();
    for (int i=1;i<=k;i++)
    {
        int x=read(),y=read();LL z=read();
        map[x][y]=max(map[x][y],z);
        lx[x]=max(lx[x],z);
    }
    int tmp=n;n=max(n,m);
    printf("%lld\n",km());
    for (int i=1;i<=tmp;i++) printf("%d ",map[i][ans[i]]?ans[i]:0);
    return 0;
}

你可能感兴趣的:(KM算法)