作者:来知晓
公众号:来知晓
刷题交流QQ群:444172041
Git项目地址:LeetCodeUsingC刷题笔记
本篇解法思路参考了labuladong的算法小抄里的文章《一个方法团灭 LeetCode 股票买卖问题》,将原来C++代码,根据C代码实现做了部分调整和简化。总体来看,东哥文章里面思路讲的已经很清楚了,下面总结下C代码实现对应框架时,个人认为的一些核心要点和注意事项。
关键点
若假设一共有n
天股价数据,最大交易次数限制为K
次,则0 <= i < n
和0 <= k <= K
,可以推出:
状态转移方程:
dp[i][k][0] = max(dp[i - 1][k][0], dp[i - 1][k][1] + price[i]);
dp[i][k][1] = max(dp[i - 1][k][1], dp[i - 1][k - 1][0] - price[i]);
解释:
max
取结果,是因为dp存的一直是到今天截止当前状态对应的最大利润。所以我们的输出才可以是dp[n - 1][K][0]。初始条件Base Case:
dp[-1][k][0] = 0; // 注意此处的k表示0到k的操作次数所有值
dp[-1][k][1] = -Infinity; // 表示不可能
dp[i][0][0] = 0; // 注意此处的i表示0到i-1的天数所有值
dp[i][0][1] = -Infinity;
输出的最大利润:
dp[n - 1][k][0] // 最后一天,手上不持有股票的时候有最大利润
以上为核心框架分析,下面开始实战。
K = 1
,即仅允许交易一次的情况,因为我们最终只关心dp[n - 1][K][0]
时的输出,所以状态转移方程中k
可以直接取1,而dp[i - 1][0][0]
恒为0,所以可以用k = 1
一直递推,从而降维得到:
dp[i][0] = max(dp[i - 1][0], dp[i - 1][1] + prices[i])
dp[i][1] = max(dp[i - 1][1], -prices[i])
K = +infinity
,即不限交易次数的情况,每天的dp里显然只需要存储最大交易次数对应的利润即可。而我们只关心最终dp[n - 1][K][0]
的输出,且K无穷大时,dp[i - 1][K - 1][0] = dp[i - 1][K][0]
,所以状态转移方程可以化简为:
dp[i][0] = max(dp[i - 1][0], dp[i - 1][1] + prices[i])
dp[i][1] = max(dp[i - 1][1], dp[i - 1][0] - prices[i])
K = 2
,即仅能交易两次,注意以下两点:
容易想到,冷冻期一天的限制条件可转化为:一旦发现卖出,则在下一天循环时跳过。以下为对应代码实现:
#include
// 采用状态机来遍历所有可能的状态
#define max(a, b) (a) > (b) ? (a) : (b)
int maxProfit(int* prices, int pricesSize)
{
// k = +infinity
int dp_i_0 = 0;
int dp_i_1 = INT_MIN;
int freeze_flag = 0;
int i;
for (i = 0; i < pricesSize; i++) {
// dp[i][k][0] = max(dp[i - 1][k][0], dp[i - 1][k][1] + price[i]); // 不动或卖出
// dp[i][k][1] = max(dp[i - 1][k][1], dp[i - 1][k - 1][0] - price[i]); // 不动或买入
if (freeze_flag == 1) {
freeze_flag = 0;
continue; // 卖出后冻结一天
}
int tmp_dp_i_0 = dp_i_0;
// dp_i_0 = max(dp_i_0, dp_i_1 + prices[i]);
if (dp_i_1 + prices[i] > dp_i_0) { // 若卖出
dp_i_0 = dp_i_1 + prices[i];
freeze_flag = 1;
}
dp_i_1 = max(dp_i_1, tmp_dp_i_0 - prices[i]);
}
return dp_i_0;
}
但是以上代码在过用例{1,2,4}
时,输出错误。分析发现,状态转移方程更新的不对,状态转移方程中这句:
dp[i][k][1] = max(dp[i - 1][k][1], dp[i - 1][k - 1][0] - price[i])
应该是dp[i - 2][k - 1][0]
。修改上面19-25行代码为:
int tmp_dp_i_0 = dp_i_0;
// dp_i_0 = max(dp_i_0, dp_i_1 + prices[i]);
if (dp_i_1 + prices[i] > dp_i_0) { // 若卖出
dp_i_0 = dp_i_1 + prices[i];
freeze_flag = 1;
}
dp_i_1 = max(dp_i_1, pre_dp_i_0 - prices[i]);
pre_dp_i_0 = tmp_dp_i_0;
再测试时,依然输出结果不对。深入分析后,发现前面说的转化思路不对,continue
会导致无法迭代每一步转移矩阵,从而无法得到所有可能情况。
比如输入{1,2,4}时,在不持有的可能中,第二天就卖了,检测到卖了标志,第三天就冻结了。而实际最优的结果不是取的第二天就卖,而是取的第二天持有,然后第三天卖这样的操作,才能获得利润最大化。所以代码根据其中的一种卖出可能,就当做真的已经卖出跳过其他可能的遍历了。
正确做法如下:
i-1
变为 i-2
正确代码如下:
#include
#define max(a, b) (a) > (b) ? (a) : (b)
int maxProfit(int* prices, int pricesSize)
{
// k = +infinity
int dp_i_0 = 0;
int dp_i_1 = INT_MIN;
int i;
int pre_dp_i_0 = 0;
for (i = 0; i < pricesSize; i++) {
// dp[i][k][0] = max(dp[i - 1][k][0], dp[i - 1][k][1] + price[i]); // 不动或卖出
// dp[i][k][1] = max(dp[i - 1][k][1], dp[i - 2][k - 1][0] - price[i]); // 不动或买入
int tmp = dp_i_0; // 更新前
dp_i_0 = max(dp_i_0, dp_i_1 + prices[i]); // 当前更新后
dp_i_1 = max(dp_i_1, pre_dp_i_0 - prices[i]);
pre_dp_i_0 = tmp; // 下次循环用时,就变成了前两天的
}
return dp_i_0;
}
此题要注意初始条件会从INT_MIN再减个手续费,可能导致整数类型溢出。需要注意防止超过int最小值, 如 INT_MIN + 1 - 2
就溢出了。
解决的小技巧:
// 原来
dp_i_0 = max(dp_i_0, dp_i_1 + prices[i] - fee);
// 改为
int tmp = ((long long)dp_i_1 + prices[i] - fee < INT_MIN) ?
INT_MIN : (dp_i_1 + prices[i] - fee);
dp_i_0 = max(dp_i_0, tmp);
完整代码如下:
// 采用状态机来遍历
/**** 模板 **********
状态转移方程:
dp[i][k][0] = max(dp[i - 1][k][0], dp[i - 1][k][1] + price[i]);
dp[i][k][1] = max(dp[i - 1][k][1], dp[i - 1][k - 1][0] - price[i]);
Base Case:
dp[-1][k][0] = 0;
dp[-1][k][1] = -Infinity; //不可能
dp[i][0][0] = 0;
dp[i][0][1] = -Infinity;
output:
dp[n - 1][k][0]
********************/
#include
#define max(a, b) (a) > (b) ? (a) : (b)
int maxProfit(int* prices, int pricesSize, int fee)
{
// k = +infinity
int dp_i_0 = 0;
int dp_i_1 = INT_MIN;
int i;
for (i = 0; i < pricesSize; i++) {
// dp[i][k][0] = max(dp[i - 1][k][0], dp[i - 1][k][1] + price[i] - fee); // 卖出时才扣手续费
// dp[i][k][1] = max(dp[i - 1][k][1], dp[i - 1][k - 1][0] - price[i]);
int tmp_dp_i_0 = dp_i_0;
// 注意防止超过int最小值, 如 INT_MIN + 1 - 2 就溢出了
int tmp = ((long long)dp_i_1 + prices[i] - fee < INT_MIN) ? INT_MIN : (dp_i_1 + prices[i] - fee);
dp_i_0 = max(dp_i_0, tmp);
dp_i_1 = max(dp_i_1, tmp_dp_i_0 - prices[i]);
}
return max(dp_i_0, dp_i_1);
}
完整C代码实现,可去GitHub仓:LeetCodeUsingC刷题笔记查看。