题目:https://www.luogu.org/problem/P1006
文章前半部分讲常规解法,后半部分给出了采用滚动数组的时间复杂度O((n+m)*n*m)、空间复杂度O(n*m)的解法。
可以找两条路径,都是从(1,1)出发到(n,m),显而易见,两路径最后汇集于(n,m),那么之前的状态应该是两个不重合的点,因此引入状态d[x1][y1][x2][y2]。状态转移方程详见代码。
下面有两种解法:
解法一关键代码:
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
for(int k=1;k<=m;k++)
for(int t=1;t<=n;t++){
int ll=d[i][j-1][k][t-1];
int lu=d[i][j-1][k-1][t];
int ul=d[i-1][j][k][t-1];
int uu=d[i-1][j][k-1][t];
d[i][j][k][t]=max(max(ll,lu),max(ul,uu))+a[i][j]+a[k][t];
if(i==k&&j==t)d[i][j][k][t]-=a[k][t];
}
也就是说,需要就两个点是否重合进行讨论。
这里的困惑在于,最优答案中的两条路径在中间过程中会不会重合呢?是不会的!用图简单说明如下:
如图一所示,两条路径一为红二为蓝,产生交汇,则改成图二,因为矩阵所有点的值非负,所以图三的答案比图一的答案大。
于是就有了AC代码一。
解法二关键代码:
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
for(int k=1;k<=i;k++)
for(int t=j+1;t<=n;t++){
int ll=d[i][j-1][k][t-1];
int lu=d[i][j-1][k-1][t];
int ul=d[i-1][j][k][t-1];
int uu=d[i-1][j][k-1][t];
d[i][j][k][t]=max(max(ll,lu),max(ul,uu))+a[i][j]+a[k][t];
}
这里在设置循环变量k与t时作了限制,保证了点(k,t)始终不在点(i,j)的左边及下方,并且不重合。
最终答案应该是d[n][m-1][n-1][m]。
需要说明的是,这种解法的状态定义d[x1][y1][x2][y2],要求两点不能重合!而且它的子状态的个数远远少于解法一的子状态。
解法三:只用开三维数组 d[t][x1][x2]
因为点的横纵坐标和 t 取值属于[2,m+n],那么再引入两个点的横坐标x1与x2,则这两个点的纵坐标为t-x1与t-x2。当然需要注意纵坐标属于[1,n]。给出关键代码如下:
for(int i=1;i<=m+n;i++) //横纵坐标和
for(int j=1;j<=m;j++) //i-j即为第一点的纵坐标
for(int k=1;k<=m;k++) //i-k即为第二点的纵坐标
{
if(i-j<1||i-j>n||i-k<1||i-k>n) continue;//横坐标限定
int d1=max(d[i-1][j-1][k-1],d[i-1][j-1][k]);//上上,上左
int d2=max(d[i-1][j][k-1],d[i-1][j][k]);//左上,左左
d[i][j][k]=max(d1,d2)+a[j][i-j]+a[k][i-k];
if(j==k) d[i][j][k]-=a[j][i-j];
}
答案是d[m+n][n][m]。
再前进一步,在解法三的基础上,可以发现循环变量i可以滚动掉,看懂下面的关键代码需要对完全背包、01背包十分理解才行。
解法四关键代码:
for(int i=1;i<=m+n;i++) //横纵坐标和。顺序
for(int j=m;j>=1;j--) //i-j即为第一点的纵坐标。逆序
for(int k=m;k>=1;k--) //i-k即为第二点的纵坐标。逆序
{
if(i-j<1||i-j>n||i-k<1||i-k>n) continue;
int d1=max(d[j-1][k-1],d[j-1][k]);//上上,上左
int d2=max(d[j][k-1],d[j][k]);//左上,左左
d[j][k]=max(d1,d2)+a[j][i-j]+a[k][i-k];
if(j==k) d[j][k]-=a[j][i-j]; //路径点重合
}
AC代码一(对应解法一):
#include
#include
#include
#include
using namespace std;
int d[51][51][51][51];
int a[51][51];
int main()
{
int m,n;
scanf("%d%d",&m,&n);
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
scanf("%d",&a[i][j]);
for(int i=1;i<=m;i++) //阶段
for(int j=1;j<=n;j++)
for(int k=1;k<=m;k++) //状态
for(int t=1;t<=n;t++){
int ll=d[i][j-1][k][t-1];
int lu=d[i][j-1][k-1][t];
int ul=d[i-1][j][k][t-1];
int uu=d[i-1][j][k-1][t];
//决策
d[i][j][k][t]=max(max(ll,lu),max(ul,uu))+a[i][j]+a[k][t];
if(i==k&&j==t)d[i][j][k][t]-=a[k][t];
}
printf("%d\n",d[m][m][m][n]);
return 0;
}
AC代码二(对应解法二):
#include
#include
#include
#include
using namespace std;
int d[51][51][51][51];
int a[51][51];
int main()
{
int m,n;
scanf("%d%d",&m,&n);
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
scanf("%d",&a[i][j]);
for(int i=1;i<=m;i++)//阶段
for(int j=1;j<=n;j++)
for(int k=1;k<=i;k++)//状态
for(int t=j+1;t<=n;t++){
int ll=d[i][j-1][k][t-1];
int lu=d[i][j-1][k-1][t];
int ul=d[i-1][j][k][t-1];
int uu=d[i-1][j][k-1][t];
//决策
d[i][j][k][t]=max(max(ll,lu),max(ul,uu))+a[i][j]+a[k][t];
}
printf("%d\n",d[m][n-1][m-1][n]);
return 0;
}
AC代码三(解法三):
#include
#include
#include
#include
using namespace std;
int d[102][51][51];
int a[51][51];
int main()
{
int m,n;
scanf("%d%d",&m,&n);
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
scanf("%d",&a[i][j]);
for(int i=1;i<=m+n;i++) //横纵坐标和
for(int j=1;j<=m;j++) //i-j即为第一点的纵坐标
for(int k=1;k<=m;k++) //i-k即为第二点的纵坐标
{
if(i-j<1||i-j>n||i-k<1||i-k>n) continue;
int d1=max(d[i-1][j-1][k-1],d[i-1][j-1][k]);//上上,上左
int d2=max(d[i-1][j][k-1],d[i-1][j][k]);//左上,左左
d[i][j][k]=max(d1,d2)+a[j][i-j]+a[k][i-k];
if(j==k) d[i][j][k]-=a[j][i-j]; //路径点重合
}
printf("%d\n",d[m+n][m][m]);
return 0;
}
AC代码四(解法四):
#include
#include
#include
#include
using namespace std;
int d[51][51];
int a[51][51];
int main()
{
int m,n;
scanf("%d%d",&m,&n);
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
scanf("%d",&a[i][j]);
for(int i=1;i<=m+n;i++) //横纵坐标和。顺序
for(int j=m;j>=1;j--) //i-j即为第一点的纵坐标。逆序
for(int k=m;k>=1;k--) //i-k即为第二点的纵坐标。逆序
{
if(i-j<1||i-j>n||i-k<1||i-k>n) continue;
int d1=max(d[j-1][k-1],d[j-1][k]);//上上,上左
int d2=max(d[j][k-1],d[j][k]);//左上,左左
d[j][k]=max(d1,d2)+a[j][i-j]+a[k][i-k];
if(j==k) d[j][k]-=a[j][i-j]; //路径点重合
}
printf("%d\n",d[m][m]);
return 0;
}