给定一系列矩阵
由于矩阵乘法要求两个相乘矩阵的维度满足:第一个矩阵的列数要与第二个矩阵的行数相同。所以我们只要用 N+1 个数字就能表示所有矩阵的维度了,这里我们用 d 来表示这 N+1 个数字, 其中 di 和 di+1 分别表示第 i 个矩阵的行数和列数。
给定一个矩阵序列 A , 我们并不需要真正计算矩阵乘法,而是给出最优时间复杂度和矩阵相乘顺序。因此,我们真正的输入是 d 。这里我们暂且不考虑输出矩阵相乘的顺序,先以求最优时间复杂度为目标解决这个优化问题。
如果你不想看分析过程可以直接看后面的算法实现部分。
如果我们用 C(⋅) 表示某个矩阵连乘序列的最优时间复杂度,那么它一定满足下面的公式:
利用动态规划的思想解决矩阵序列连乘问题的算法本身的时间复杂度跟 B 矩阵的计算有关, B 矩阵需要计算其整个上三角部分,我们逐步推导:
第 1 列: 计算 B(0,1) : 需要 1−0=1 次计算。
第 2 列: 计算 B(1,2) : 需要 2−1=1 次计算。
第 2 列: 计算 B(0,2) : 需要 2−0=2 次计算。
第 3 列: 计算 B(2,3) : 需要 3−2=1 次计算。
第 3 列: 计算 B(1,3) : 需要 3−1=2 次计算。
第 3 列: 计算 B(0,3) : 需要 3−0=3 次计算。
⋮
第 N−1 列: 计算 B(N−2,N−1) : 需要 (N−1)−(N−2)=1 次计算。
第 N−1 列: 计算 B(N−3,N−1) : 需要 (N−1)−(N−3)=2 次计算。
第 N−1 列: 计算 B(N−4,N−1) : 需要 (N−1)−(N−4)=3 次计算。
⋮
第 N−1 列: 计算 B(0,N−1) : 需要 (N−1)−(0)=N−1 次计算。
所以计算时间复杂度为:
虽然算法时间复杂度为 O(N3) , 我们只需要存储一个矩阵就可以了,所以空间复杂度是 O(N2) 。
完整的C++实现如下:
#include
#include
using namespace std;
// 寻找最优时间复杂度 B,以及最优划分 K
void find_best_complexity(vector<int> &B, vector<int> &K, const int *d, int N){
B.resize(N*N);
K.resize(N*N);
for (int i = 0; i < N; i++){
B[i*N + i] = 0;
}
for (int v = 1; v < N; v++){
for (int u = v - 1; u > -1; u--){
int best_cmp = INT_MAX;
int best_k;
for (int k = u; k < v; k++){
int current_cmp = d[u] * d[k + 1] * d[v + 1] + B[u*N + k] + B[(k+1)*N + (v)];
if (current_cmp < best_cmp){
best_cmp = current_cmp;
best_k = k;
}
}
K[u*N + v] = best_k;
B[u*N + v] = best_cmp;
}
}
}
// 输出最优时间复杂度下矩阵的相乘顺序
void print_uv(int u, int v, vector<int> &K, int &N){
if (u==v){
return;
}
int k = K[u*N+v];
print_uv(u, k, K, N);
print_uv(k + 1, v, K, N);
printf("%4d ", u);
printf("%4d ", k);
printf("%4d \n", v);
}
// 举个例子
int main(int argc, char** argv){
int d[] = {1, 2, 3, 1, 5};
int N = (sizeof(d) / sizeof(d[0])) - 1;
vector<int> B;
vector<int> K;
find_best_complexity(B, K, d, N);
printf("###############################################################\n");
printf("# B\n");
for (int u = 0; u < N; u++){
for (int v = 0; v < N; v++){
if (v < u){
printf("%4d ", -1);
}
else{
printf("%4d ", B[u*N + v]);
}
}
printf("\n");
}
printf("###############################################################\n");
printf("# K\n");
for (int u = 0; u < N; u++){
for (int v = 0; v < N; v++){
if (v < u){
printf("%4d ", -1);
}
else{
printf("%4d ", K[u*N + v]);
}
}
printf("\n");
}
printf("###############################################################\n");
printf("# order\n");
print_uv(0, N - 1, K, N);
return EXIT_SUCCESS;
}
上述代码用printf
是为了更好的格式化输出。