带有独立插头的dp,状态用4进制数表示,掌握位运算技巧还是蛮好写的,虽然很慢。。。
#include<cstdio> #include<cstring> #include<cstdlib> #include<cmath> #include<algorithm> #include<iostream> #define maxn 100010 #define inf 1000000000 using namespace std; int n,m,now,pre,num,ans; int hash[2][maxn],head[10010],to[maxn],next[maxn],tot[2],f[2][maxn]; int a[110][110],bit[50]; void add(int s,int x) { int pos=s%10000; for (int p=head[pos];p;p=next[p]) if (hash[now][to[p]]==s) { f[now][to[p]]=max(f[now][to[p]],x); return; } tot[now]++; hash[now][tot[now]]=s; f[now][tot[now]]=x; to[++num]=tot[now];next[num]=head[pos];head[pos]=num; } int find_pos(int s,int i) { return (s/(1<<bit[i-1]))%4; } int find_l(int s,int k) { int cnt=0; for (int i=k;i>=1;i--) { int p=find_pos(s,i); if (p==2) cnt++; else if (p==1) cnt--; if (!cnt) return i; } } int find_r(int s,int k) { int cnt=0; for (int i=k;i<=m+1;i++) { int p=find_pos(s,i); if (p==1) cnt++; else if (p==2) cnt--; if (!cnt) return i; } } void dp() { now=1;pre=0; tot[now]=1; hash[now][1]=0; f[now][1]=0; for (int i=1;i<=n;i++) { for (int j=1;j<=tot[now];j++) hash[now][j]<<=2; for (int j=1;j<=m;j++) { swap(now,pre); for (int k=1;k<=tot[now];k++) f[now][k]=-1; num=0;tot[now]=0; memset(head,0,sizeof(head)); for (int k=1;k<=tot[pre];k++) { int s=hash[pre][k],num=f[pre][k]+a[i][j]; if (s>=(1<<bit[m+1])) continue; int p=find_pos(s,j),q=find_pos(s,j+1),e=s-(p<<bit[j-1])-(q<<bit[j]); if (!p && !q) { add(e,num-a[i][j]); add(e+(1<<bit[j-1])+(2<<bit[j]),num); add(e+(3<<bit[j-1]),num); add(e+(3<<bit[j]),num); } else if (!p) { if (q==1) { add(e+(1<<bit[j-1]),num); add(e+(1<<bit[j]),num); add(e^(1<<bit[find_r(s,j+1)-1]),num); } else if (q==2) { add(e+(2<<bit[j-1]),num); add(e+(2<<bit[j]),num); add(e^(2<<bit[find_l(s,j+1)-1]),num); } else { add(e+(3<<bit[j-1]),num); add(e+(3<<bit[j]),num); if (!e) ans=max(ans,num); } } else if (!q) { if (p==1) { add(e+(1<<bit[j-1]),num); add(e+(1<<bit[j]),num); add(e^(1<<bit[find_r(s,j)-1]),num); } else if (p==2) { add(e+(2<<bit[j-1]),num); add(e+(2<<bit[j]),num); add(e^(2<<bit[find_l(s,j)-1]),num); } else { add(e+(3<<bit[j-1]),num); add(e+(3<<bit[j]),num); if (!e) ans=max(ans,num); } } else if (p==1 && q==1) add(e^(3<<bit[find_r(s,j+1)-1]),num); else if (p==2 && q==2) add(e^(3<<bit[find_l(s,j)-1]),num); else if (p==2 && q==1) add(e,num); else if (p==3 && q==1) add(e^(1<<bit[find_r(s,j+1)-1]),num); else if (p==3 && q==2) add(e^(2<<bit[find_l(s,j+1)-1]),num); else if (p==1 && q==3) add(e^(1<<bit[find_r(s,j)-1]),num); else if (p==2 && q==3) add(e^(2<<bit[find_l(s,j)-1]),num); else if (p==3 && q==3) { if (!e) ans=max(ans,num); } } } } } int main() { for (int i=1;i<=20;i++) bit[i]=i*2; scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) for (int j=1;j<=m;j++) { scanf("%d",&a[i][j]); ans=max(ans,a[i][j]); } dp(); printf("%d\n",ans); return 0; }