【BZOJ】3996: [TJOI2015]线性代数

题意

给出一个\(N \times N\)的矩阵\(B\)和一个\(1 \times N\)的矩阵\(C\)。求出一个\(1 \times N\)的01矩阵\(A\),使得\[ D = ( A * B - C ) * A^T \]最大,其中\(A ^ T\)是矩阵\(A\)的转置。(\(n<=500\)

分析

好神的题。首先我们容易推出一个式子:
\[ D = \sum_{i=1}^{n} \sum_{j=1}^{n} a_i \times a_j \times b_{i, j} - \sum_{i=1}^{n} a_i \times c_i \]

题解

\(b_{i, j}\)则必须选\(a_i\)\(a_j\),而选\(a_i\)就必须选\(c_i\)
那么可以看成:
\(a_i\)为一个点,权值为\(-c_i\),表示选了\(a_i\)就必须选\(c_i\)
\(b_{i, j}\)为一个点,权值为\(b_{i, j}\),向\(a_i、a_j\)连有向边,表示选了\(b_{i, j}\)就必须选\(a_i、a_j\)
于是问题变成求最大权闭合子图....
由于某种原因,这个网络流跑的飞起?似乎成了二分图复杂度变成\(O(mn^{0.5})\)了?感觉不科学啊QAQ

#include <bits/stdc++.h>
using namespace std;
inline int getint() {
    int x=0, c=getchar();
    for(; c<48||c>57; c=getchar());
    for(; c>47&&c<58; x=x*10+c-48, c=getchar());
    return x;
}
const int N=605, vN=N*N+N, oo=~0u>>1;
int ihead[vN], cnt=1;
struct E {
    int next, to, cap;
}e[6*N*N+2*N];
inline void add(int x, int y, int cap) {
    e[++cnt]=(E){ihead[x], y, cap}; ihead[x]=cnt;
    e[++cnt]=(E){ihead[y], x, cap}; ihead[y]=cnt;
}
inline int min(const int &a, const int &b) {
    return a<b?a:b;
}
int isap(int s, int t, int n) {
    static int gap[vN], cur[vN], d[vN], p[vN];
    gap[0]=n;
    int r=0, x=s, i, f;
    for(; d[s]<n;) {
        for(i=cur[x]; i && !(e[i].cap && d[x]==d[e[i].to]+1); i=e[i].next);
        if(i) {
            p[e[i].to]=cur[x]=i;
            if((x=e[i].to)==t) {
                for(f=oo, x=t; x!=s; f=min(f, e[p[x]].cap), x=e[p[x]^1].to);
                for(r+=f, x=t; x!=s; e[p[x]].cap-=f, e[p[x]^1].cap+=f, x=e[p[x]^1].to);
            }
        }
        else {
            if(!--gap[d[x]]) break;
            d[x]=n;
            for(i=ihead[x]; i; i=e[i].next) {
                if(e[i].cap && d[x]>d[e[i].to]+1) {
                    d[x]=d[e[i].to]+1;
                    cur[x]=i;
                }
            }
            ++gap[d[x]];
            if(x!=s) x=e[p[x]^1].to;
        }
    }
    return r;
}
int main() {
    int n=getint(), sum=0, S, T;
    S=n*(n+1)+1, T=S+1;
    for(int i=1; i<=n; ++i) {
        for(int j=1; j<=n; ++j) {
            int id=i*n+j, w=getint();
            add(id, T, w);
            if(i!=j) add(i, id, oo);
            add(j, id, oo);
            sum+=w;
        }
    }
    for(int i=1; i<=n; ++i) {
        add(S, i, getint());
    }
    printf("%d\n", sum-isap(S, T, T));
    return 0;
}



你可能感兴趣的:(【BZOJ】3996: [TJOI2015]线性代数)