题意:给一个矩阵,让找一条简单路径使路径和最小。N<=100, M<=8。
如果矩阵全是正数的话枚举起点终点跑跑费用流就好了,省选的话一定会给这个部分分。
考虑用括号序列的插头DP,增加一个3号表示起点和终点那种插头。注意这个1表示左括号,2表示右括号,3表示其它这个是有讲究的,可以用xor来很方便地转化。
分类讨论比较繁琐,参考了一下claris的才完整地写出来。但是有些细节的处理上要注意和回路的区别。还有就是要注意括号端点和插头3的转化,我只能感觉应该这样做,但是它的本质还是不太理解。
#include<iostream> #include<algorithm> #include<cstdio> #include<cstring> #define rep(i,a,b) for(int i=a;i<=b;++i) #define erp(i,a,b) for(int i=a;i>=b;--i) #define getbit(x,y) (((x) >> ((y)<<1)) & 3) #define bit(x,y) ((x)<<((y)<<1)) #define clrbit(x,i,j) ((x) & (~(3<<((i)<<1))) & (~(3<<((j)<<1)))) #define LL long long using namespace std; const int mo = 10003, MAXS = 100000, inf = 0x3f3f3f3f; void gmax(int&a, int b) { a<b?a=b:0; } int N, M, ans; int a[105][10]; struct Node { int s, nxt, val; }; struct Hash { Node e[MAXS]; int adj[mo], ec; void init() { memset(adj, -1, sizeof adj); ec = 0; } void push(int s, int v) { int ha = s%mo; for (int i = adj[ha]; ~i; i=e[i].nxt) if (e[i].s == s) return gmax(e[i].val, v); e[ec].val = v, e[ec].s = s; e[ec].nxt = adj[ha]; adj[ha] = ec++; } } dp[2]; inline int FindL(int st, int x) { int cnt = 1, s; erp(i, x-1, 0) { s = (st >> (i<<1)) & 3; if (s == 2) cnt++; else if (s == 1) cnt--; if (!cnt) return i; } return -1; } inline int FindR(int st, int x) { int cnt = 1, s; rep(i, x+1, M) { s = (st >> (i<<1)) & 3; if (s == 1) cnt++; else if (s == 2) cnt--; if (!cnt) return i; } return -1; } void work(int i, int j, int cur) { dp[cur].init(); rep(k, 0, dp[cur^1].ec-1) { int lass = dp[cur^1].e[k].s; if (lass >= (1<<((M+1)<<1))) continue; int L = getbit(lass, j-1); int U = getbit(lass, j); int s = clrbit(lass, j-1, j), w = a[i][j]; LL las = dp[cur^1].e[k].val; if (!L && !U) { dp[cur].push(s, las); dp[cur].push(s | bit(1, j-1) | bit(2, j), las + w); dp[cur].push(s | bit(3, j-1), las + w); dp[cur].push(s | bit(3, j), las + w); } else if (!L || !U) { int t = L+U; dp[cur].push(s | bit(t, j-1), las + w); dp[cur].push(s | bit(t, j), las + w); if (t == 3) { if (!s) gmax(ans, las + w); } else { if (L==1) dp[cur].push(s ^ bit(L, FindR(s, j-1)), las + w); if (L==2) dp[cur].push(s ^ bit(L, FindL(s, j-1)), las + w); if (U==1) dp[cur].push(s ^ bit(U, FindR(s, j)), las + w); if (U==2) dp[cur].push(s ^ bit(U, FindL(s, j)), las + w); } } else if (L==1 && U==1) dp[cur].push(s^bit(3, FindR(s, j)), las + w); else if (L==2 && U==1) dp[cur].push(s, las + w); else if (L==2 && U==2) dp[cur].push(s^bit(3, FindL(s, j-1)), las + w); else if (L==3 && U==3) { if (!s) gmax(ans, las + w); } else if (L==3) { if (U==1) dp[cur].push(s^bit(U, FindR(s, j)), las + w); else dp[cur].push(s^bit(U, FindL(s, j)), las + w); } else if (U==3) { if (L==1) dp[cur].push(s^bit(L, FindR(s, j-1)), las + w); else dp[cur].push(s^bit(L, FindL(s, j-1)), las + w); } } } int solve() { dp[0].init(), dp[0].push(0, 0); int cur = 0; rep(i, 1, N) { rep(k, 0, dp[cur].ec-1) dp[cur].e[k].s <<= 2; rep(j, 1, M) { cur ^= 1; work(i, j, cur); } } return ans; } int main() { scanf("%d%d", &N, &M); rep(i, 1, N) rep(j, 1, M) scanf("%d", &a[i][j]), gmax(ans, a[i][j]); printf("%d\n", solve()); return 0; }