详解leetcode石头合并问题

https://leetcode-cn.com/problems/minimum-cost-to-merge-stones/

文章目录

      • 分析
        • 区间动态规划三部曲
      • 解法一
        • 思路
        • 细节
        • 代码
      • 解法二
        • 代码

分析

这道题是一道经典的区间dp问题,旨在通过动态规划去求一个区间的最优解,通过将大区间划分为很多个小区间,再由小区间的解来组合出大区间的解,这体现了分治的思想。

区间动态规划三部曲

  1. 定义状态:dp[i, j]为区间[i, j]的最优解
  2. 定义状态转移方程:最常见的写法为:dp[i,j] = max/min{dp[i,j], dp[i, k] + dp[k+1, j] + cost}。选取[i, j]之间的一个分界点k,分别计算[i, k]和[k+1, j]的最优解,从而组合出[i, j]的最优解。
  3. 初始化:dp[i][i] = 常数。区间长度为1时的最优解应当是已知的。

假设要求的区间最优解为dp[1, n],区间dp问题有两种编码方法:
第一种:

for (int i = n; i >= 1; --i) {
     
	for (int j = i + 1; j <= n; ++j) {
     
		for (int k = i; k < j; ++k) {
     
			dp[i,j] = max/min(dp[i,j], dp[i,k] + dp[k+1, j] + cost)
		}
	}
}

这种写法就是常规的dp写法,枚举i为子区间左边界,枚举j为子区间有边界,枚举k为分界点。要注意由于要求的是dp[1,n],所以i必须从大往小遍历,j必须从小往大遍历。这样在状态转移方程中利用的就是已求解的dp状态。
第二种:

for (int len = 2; len <= n; ++len) {
     
	for (int i = 1; i + len - 1  <= n; ++i) {
     
		int j = i + len - 1;
		for (int k = i; k < j; ++k) {
     
			dp[i,j] = max/min(dp[i,j], dp[i,k] + dp[k+1, j] + cost)
		}
	}
}

这种写法最常见,枚举len为区间长度,枚举i为区间左端点,由此可以计算出区间右端点j,枚举k为分界点。区间长度从2到n,跟上一种写法相同。这种写法的正确性可能不如上一种那么直观,它从小到大枚举出所有区间,在求解大区间时,状态转移方程中利用的状态都是小区间的状态,必定在它之前被求解,所以也是正确的。

解法一

思路

回到石头合并问题,如何把它转换成一个区间dp问题呢?

首先考虑简单的情况,如果一次只能合并两堆石头,如何求解?
直接跳到最后的情况,一定是只剩下两堆石头,我们再把它们合并成一堆石头,这一步的成本是多少?成本为sum(1, n),即所有石头数目之和。因为合并的成本是两堆石头数目之和,而这两堆石头的数目之和一定是所有石头数目之和。

将所有石头合并成两堆的成本是多少呢?我们可以将所有石头划分为左右两部分,则成本就是分别将左部分和右部分合并成一堆的成本。子问题就这样出现了,左右两部分就相当于划分出的两个小区间。

定义dp[i][j]为合并第i到第j堆石头为一堆的成本,则dp[i][j] = min(dp[i][p] + dp[p + 1][j] + sum(i, j)) | i <= p < jp为分界点。初始化dp[i][i] = 0,答案为dp[1][n]

再来考虑一次合并k堆的情况。最后一定是变成k堆石头,这一步合并的成本依然是这k堆石头数目之和,也即所有石头数目之和。将所有石头合并成k堆的成本是多少?

我们同样对所有石头进行划分,左部分最终要合并成1堆,而右部分最终要合并为k - 1堆。这样左部分就是一个子问题,而右部分又是一个变相的子问题(将它也继续划分为两部分来求解)。

定义dp[i][j][k]为合并第i堆到第j堆石头为k堆的成本,状态转移方程有关键两点:

  1. dp[i][j][1] = dp[i][j][k] + sum(i, j)。不能直接求出合并为1堆的成本,得先合并成k堆。
  2. dp[i][j][m] = min(dp[i][p][1] + dp[p + 1][j][m - 1])i <= p < j2 <= m <= k。这里m为堆数,不能直接用k是因为右部分的存在,要对右部分继续划分求解的话,子问题就必须有合并成m堆的情况。

初始化dp[i][i][1] = 0,答案就是dp[1][n][1]。对于无法实现的情况,定义dp[i][j][k] = max

细节

第一点:一定会有不能合并成1堆的情况,怎么排除掉这种情况呢?
如果能合并成1堆,就一定得先合并成k堆,这在前面已经讨论过了。这k堆里面的其中1堆,也是由k堆合并而来的,这样一直套娃,就能还原到原始的堆数n。我们由此可以定义一个方程:k + (k - 1) * a == n,a是一个大于等于0的整数。
推算一下,有:k - 1 + (k - 1) * a == n - 1 ⇒ \Rightarrow (k - 1) * (a + 1) == n - 1
所以对于有解的情况,一定有(n - 1) % (k - 1) == 0

第二点:为什么划分的方式是左部分合并成1堆,右部分合并成k-1堆?左部分k-1,右部分1;左部分2,右部分k-2…这些方式可行吗?
可行的划分方式只能是1和k-1,左右当然不重要

首先说明1和k-1能完整覆盖到所有情况:
如果对于dp[i][j][m],它的最优划分是dp[i][p][2] + dp[p + 1][j][m - 2]
那么dp[i][p][2] = dp[i][p1][1] + dp[p1 + 1][p][1]p1为最优划分点。
代入一下,就有dp[i][j][m] = dp[i][p1][1] + dp[p1 + 1][p][1] + dp[p + 1][j][m - 2]
后面那俩合并一下就是k-1堆的情况,所以说1和k-1的划分方式是正确的。

再说明为什么2和k-2的划分是错误的:
这一点要从递归的角度,自顶向下地来看就好理解。我们要求解的是solve(1, n, 1),由于堆数为1,所以会递归调用solve(1, n, k)。堆数为k,需要进行划分来求解,分别调用solve(1, p, 2)solve(p + 1, n, k - 2),p从1到n-1循环。当p == 1p == 2时我们都知道结果,但当3 <= p < n呢?solve(1, p, 2)不是一个初始状态,也不是可以划分的状态,也不知道是不是合法的状态,这就变成了一个无法求解的状态,所以划分是错误的。
再回到dp的角度,其实也就是dp[i][p][2]是无法求解的,合并成2堆不是一个子问题,而我们定义的划分方式又导致它无法继续分解为子问题,那它就肯定无法求解了。

第三点:枚举分界点时,step应该是k - 1而不是1
step为1当然也是正确的,但是却进行了很多无用的计算,导致运行时间增加。为什么step可以是k-1呢?因为我们设计的划分是将左部分区间[i, p]合并为1堆,那就一定有(p - i) % (k - 1) == 0,结合最初p = i,就可以知道step应该是k-1,这样会涵盖所有有效的分界点p。对于其他的分界点p,左部分不能合并为1堆,那这样的划分并没有意义,对于计算答案也就没有帮助了。

代码

class Solution {
     
    // 不用Integer.MAX_VALUE,因为Integer.MAX_VALUE + 正数 会溢出变为负数
    private int MAX_VALUE = 99999999; 
    public int mergeStones(int[] stones, int k) {
     
        int n = stones.length;
        if ((n - 1) % (k - 1) != 0) return -1;
        int[][][] dp = new int[n + 1][n + 1][k + 1];
        int[] sum = new int[n + 1];
        for (int i = 1; i <= n; ++i) sum[i] = sum[i - 1] + stones[i - 1];
        for (int i = 1; i <= n; ++i) {
     
            for (int j = i; j <= n; ++j) {
     
                for (int m = 2; m <= k; ++m) dp[i][j][m] = MAX_VALUE;
            }
            dp[i][i][1] = 0;
        }
        for (int len = 2; len <= n; ++len) {
      // 枚举区间长度
            for (int i = 1; i + len - 1 <= n; ++i) {
      // 枚举区间起点
                int j = i + len - 1;
                for (int m = 2; m <= k; ++m) {
      // 枚举堆数
                    for (int p = i; p < j; p += k - 1) {
       // 枚举分界点
                        dp[i][j][m] = Math.min(dp[i][j][m], dp[i][p][1] + dp[p + 1][j][m - 1]);
                    }
                }
                dp[i][j][1] = dp[i][j][k] + sum[j] - sum[i - 1];
            }
        }
        return dp[1][n][1];
    }
}

参考https://leetcode.com/problems/minimum-cost-to-merge-stones/discuss/247657/JAVA-Bottom-Up-%2B-Top-Down-DP-With-Explaination

解法二

上述解法的时间复杂度是 O ( n 3 ∗ k ) O(n^3*k) O(n3k),我们可以对它进行优化。
定义dp[i][j]尽可能多的合并区间[i, j] 所需的成本,不一定能合并成一堆,但合并完成后剩下的堆数一定小于k,更具体地,剩余的堆数一定是(n - 1) % (k - 1) + 1
证明
已知一次合并会导致堆数减少k-1,假设最多进行了a次合并,则有
remain = n - (k - 1) * a1 <= remain <= k - 1
⇒ \Rightarrow remain - 1 = n - 1 - (k - 1) * a
⇒ \Rightarrow remain - 1 = (n - 1) % (k - 1)
⇒ \Rightarrow remain = (n - 1) % (k - 1) + 1
证毕。

我们参照解法一来定义状态转移方程,同样将区间[i,j]划分为两部分。
我们保证将左部分合并成1堆,而尽可能多地合并右部分。(左部分需要满足(len - 1) % (k - 1) == 0)。
右部分剩余堆数满足1 <= remain <= k - 1,如果最后右部分剩余k-1堆(也即(j - i) % (k - 1) == 0),则还可以继续将这两部分合并成1堆。
因此合并区间[i,j]的成本是合并其左右部分成本之和(对于最优的划分)。如果可以进一步合并的话,则额外的成本是sum(i, j)
状态转移方程为:dp[i][j] = min(dp[i][p] + dp[p + 1][j]), i <= p < j,如果可以继续合并,dp[i][j] += sum(i, j)

这样的话枚举的区间长度就必须从k开始了,因为长度在[1,k-1]之间的区间已经无法进行合并了,它们的dp[i][j] == 0

代码

class Solution {
     
    public int mergeStones(int[] stones, int k) {
     
        int n = stones.length;
        if ((n - 1) % (k - 1) != 0) return -1;
        int[][] dp = new int[n + 1][n + 1];
        int[] sum = new int[n + 1];
        for (int i = 1; i <= n; ++i) sum[i] = sum[i - 1] + stones[i - 1];
        for (int len = k; len <= n; ++len) {
      // 枚举区间长度
            for (int i = 1; i + len - 1 <= n; ++i) {
      // 枚举区间起点
                int j = i + len - 1;
                dp[i][j] = Integer.MAX_VALUE;
                for (int p = i; p < j; p += k - 1) {
       // 枚举分界点
                    dp[i][j] = Math.min(dp[i][j], dp[i][p] + dp[p + 1][j]);
                }
                if ((j - i) % (k - 1) == 0) dp[i][j] += sum[j] - sum[i - 1];
            }
        }
        return dp[1][n];
    }
}

你可能感兴趣的:(算法)