题意:有一个N*M的01矩阵a,已知每一行有多少个1以及每一列有多少1。现在这个矩阵搞丢了,但是会告诉你a[i,j]是1的概率p[i,j](一个[0,100)的整数表示百分率)。让你还原出一个概率最大的符合条件的01矩阵,任意输出一个。
据说这题只能用zkw费用流过。。这个太奇怪了。
这题调了我一个下午,很有启发意义。
首先这是经典的矩阵还原模型,就是行做X部,列做Y部,之间的连边的流量代表a[i,j]的值,之前做过两道类似的题,因此这里还是想到了的。
然后就是最精华的地方了,概率应该是所有的p[i,j]的乘积,但是费用流的模型是和,我开始试着改一下费用流的写法,就是把中间找最短路的地方改成乘法,这样理论上可行但是乘不到几个数就会严重暴long long了。。然后我就上网搜了,我惊讶地发现竟然没几个人写这道题的题解,并且没人贴程序。这题有这么偏门吗。。找了半天才找到,要把费用设为概率的对数!这个真是太巧妙了,因为两个正数的大小关系完全等同于他们取对数的大小,然后两个数相乘就对应了他们的对数相加!!!然后要注意因为是最大费用,将他们费用全部取反,最后再取反回来即可。
然后我就无脑地写了一发,然后就交了,结果只过了前四个点,后面全T了。。然后我就各种调,我开始以为是连边有问题产生了负权环,然后下了组数据写了个SPFA判定,结果没负环。然后我就翻来覆去看了几遍费用流的模板,觉得没问题,,然后就抓狂了,在各种细节上试,,最后终于试出来了,就是update里面判是否是最短路径上的点那个dis的判定我直接用的等号,但事实上由于精度问题转移几次之后该相等的已经不相等了,于是我写了个eps和自定义了个等号。。好消息,没T了,但是都WA了,,,然后我又找了半天,觉得好像取对数的时候精度不好,于是把概率全部乘了很大一个数之后再取的对数,然后终于过了。以前openjudge上面有道题就是,要先乘一个数再做除法,不然精度要爆。
这题没看到别人写的代码,也完全不知道精度会有这种问题,全是自己摸索的,写出来还是很愉快。
这道题的启发:
1、比较乘积但是太大了,可以换成取对数之后的加法。
2、像什么取对数,开根号之类容易挂精度的,可以先乘一个很大的数。
3、系统的对数函数好像并不是O(1)的(或者是常数有点大)。系统自带的有log,log2,log10,分别是以e,2,10为底,其中以2位底的会比另外两个快三倍,如果只是用来比大小,最后用log2。
#include<cstdio> #include<cstring> #include<cassert> #include<cmath> #include<algorithm> using namespace std; #define DB double #define clr(a) memset(a,0,sizeof a) inline DB min(const DB&a, const DB&b) { return a < b ? a : b; } #define rep(a,b,c) for (int a=b;a<=c;++a) const int inf = 0x3f3f3f3f; const int MAXN = 205; const int MAXM = 100000; int N, M; const DB eps = 1e-10; bool cmp(DB a, DB b) //自己定义在精度误差范围内的等号 { if (a < b) swap(a, b); return a-b <= eps; } struct Ed { int to, cap; DB cost; Ed*nxt, *back; }; struct FlowNet { Ed Edge[MAXM], *ecnt, *adj[MAXN]; FlowNet () { ecnt=Edge; } DB dis[MAXN]; bool vis[MAXN]; int vn, S, T, flow; DB tot; inline void adde(int a, int b, int c, DB d) { (++ecnt)->to = b; ecnt->cap = c; ecnt->cost = d; ecnt->nxt = adj[a]; ecnt->back = ecnt+1; adj[a] = ecnt; (++ecnt)->to = a; ecnt->cap = 0; ecnt->cost = -d; ecnt->nxt = adj[b]; ecnt->back = ecnt-1; adj[b] = ecnt; } void init(int n, int s, int t) { tot = 0; flow = 0; vn = n; S = s; T = t; clr(dis), clr(adj); ecnt = Edge; } bool update() { DB tmp = 1e10; rep(i, 1, vn) if (vis[i]) for (Ed *p = adj[i]; p; p=p->nxt) if (p->cap > 0 && !vis[p->to]) tmp = min(tmp, dis[p->to]-dis[i] + p->cost); if (tmp == 1e10) return 0; for (int i = 1; i<=vn; ++i) if (vis[i]) dis[i] += tmp; return 1; } int aug(int u, int augco) { if (u == T) { tot += dis[S] * augco; flow += augco; return augco; } vis[u] = 1; int delta, augc = augco; for (Ed*p = adj[u]; p && augc; p=p->nxt) { int&v = p->to; if (!vis[v] && p->cap && cmp(dis[u],dis[v]+p->cost)) //写等号要出问题 { delta = min(p->cap, augc); delta = aug(v, delta); p->cap -= delta, p->back->cap += delta; augc -= delta; } } return augco - augc; } DB mcmf() { do { do clr(vis); while (aug(S, inf)); } while (update()); return tot; } void drawMap(int img[105][105]) { rep(i, 1, N) for (Ed*p = adj[i]; p; p=p->nxt) if (p->to > i && p->to<=N+M) img[i][p->to - N] = p->back->cap; } } G; int img[105][105]; int main() { scanf("%d%d", &N, &M); G.init(N+M+2, N+M+1, N+M+2); DB tmp; rep(i, 1, N) rep(j, 1, M) { scanf("%lf", &tmp); G.adde(i, j+N, 1, -log2(tmp*1e6)); //调整精度 } int xx; rep(i, 1, N) { scanf("%d", &xx); G.adde(G.S, i, xx, 0); } rep(j, 1, M) { scanf("%d", &xx); G.adde(N+j, G.T, xx, 0); } G.mcmf(); G.drawMap(img); rep(i, 1, N) { rep(j, 1, M) printf("%d", img[i][j]); puts(""); } return 0; }