谈使用动态规划算法解矩阵连乘问题

 一、题目描述

解矩阵连乘问题

二、样例输入

 6

 30 35 15 5 10 20 25

三、样例输出

 ((A0(A1A2))((A3A4)A5))

四、思路分析

首先大概说一下样例输入,6 代表的是六个矩阵相乘,第二行的数据中每三个连续的数据代表两个矩阵,比如 {30,35,15} 等价于矩阵 A1 = 30x35,A2 = 35x15。

首先关于这道题应该很多朋友在面试中都遇到过类似的题目,首先我们应该弄明白这道题是什么意思。对于矩阵的乘法大家应该是不陌生的,如果已经忘记的可以直接百度谷歌,对于矩阵的乘法存在这样的的一个特点,就是通过使用不同的组合方式来对多个矩阵进行乘积的时候,其实进行的相乘次数是不一样的。

对于两个矩阵的相乘的要求就是第一个矩阵的列数一定要等于第二矩阵的行数,这样在进行矩阵相乘时其实就是第一个矩阵的第一行中的元素分别去乘以第二个矩阵的第一列中对应的元素然后相加,这样就得到了结果矩阵的第一行的第一个元素,之后再用第一个矩阵的第一行中的元素去乘第二矩阵的第二列中对应的元素并相加,得到结果矩阵的第一行的第二个元素,依次向后进行同理,最后得到的结果矩阵应该是行数等于第一个矩阵,列数等于第二个矩阵的矩阵,也就是说两个矩阵相乘时总的相乘次数就应该为 第一个矩阵的行数 * 第一个矩阵的列数 * 第二个矩阵的列数。

更直观一点来看,比如第一个矩阵为 30x35,第二个矩阵为 35x15,那么结果矩阵就应该为 30x15,并且总的相乘次数应为 30*35*15 次。因此我们可以举一个小例子,比如 A1 = 10x100,A2 = 100x5,A3 = 5x50,假如矩阵组合的方式为 ((A1A2)A3) 那么总的相乘次数应为 10*100*5+10*5*50 = 7500,而假如矩阵的组合方式为 (A1(A2A3)) 那么相乘的次数应为 100*5*50+10*100*50 = 75000,我们可以发现两种不同的组合方式的相乘次数竟然相差这么多。

正是因为这样才出现了对于矩阵相乘的优化,这个问题其实主要就是需要解决的就是对于括号的放置,也就是对不同矩阵间的组合方式的探究,去寻找总相乘次数最少的组合方式。

对于这个问题我们其实最容易想到的就是暴力穷举法,即把所有可能的组合方式都列出来,然后一个一个的进行计算后再进行比较,最后选出最优的组合方式,但是这种方式的时间复杂度奇高。然后我们再换一种思路,其实对于穷举法,我们将所有的可能性列出来的时候,其实就会发现有很多情况是相同的,也就是我们进行了很多次重复的计算,而这些重复的情况其实是可以通过之前已经计算出的结果进行推导而直接得到的。说到这里,其实我们应该暂时有了一点点的想法,就是可不可以将之前已经计算过的记录进行保存,之后的计算以前面的计算结果为基础来进行。

对于上面的思路,其实说白了就是先计算一个一个小的情况,然后再将小的情况组合成大的情况。这样来看,其实解题的思路应该就是利用分治的一种思路,从底而上的进行计算,并将小的结果计算后保存下来,后面的计算直接依托于前记录得出。根据分治的思路,我们可以将所有矩阵的相乘分为左右两部分,然后每一部分都是已经求出的最优情况,这样就可以得到当前的最优解,即总次数为 左半部分的最优解+右半部分的最优解+两个部分的结果矩阵相乘次数。

那么分治算法最终分解到什么程度呢,我们知道对于单个矩阵是无法进行相乘计算的,所以当最后分解到只剩一个矩阵的时候,这个时候的总次数就为 0 ,然后我们就可以回退,去计算两个矩阵时的乘积了。这种思路可以通过递归去实现,但是我们会发现,如果使用递归去实现的话还是会存在重复的情况,并且还是没有很好的利用之前的结果。

这个时候我们就想到了动态规划算法,即也是利用分治的思路,将矩阵进行分裂,然后分别进行求取最优解,并在每一部分求取到最优解后进行记录。这种算法最重要的就是关于矩阵的划分,每次这个断点应该放在哪里?怎样才能保证划分后,两部分求取到的解是当前所有划分情况下的最优解。

因此,我们需要不断的去进行尝试,比如我们当前的矩阵从 i 开始,到 j 结束,那么必存在 0 <= i <= j <= n,如果 i = j,那么这个时候就为单矩阵,结果直接为 0 。而如果 i < j时,这时若 k 为 i 和 j 中的一个断点,那么整个的矩阵序列就被划分为 i 到 k 和 k+1 到 j 两部分,并且 i <= k < j,那么 k 就一定存在 i - j 种可能性,我们需要做的就是将每一种可能性都进行尝试求解,求取其中的最优划分方式,并且这时最优的结果应该就是两部分的结果相加再加上两部分结果矩阵相乘的次数,

谈使用动态规划算法解矩阵连乘问题_第1张图片

通过上面的大概分析大家应该已经有一些思路了,下面再仔细分析一下解法,首先,我们需要两个数组,一个用于记录每个区间内的最小相乘次数(比如 arr[i][j] 为从 i 到 j 中的最小相乘次数,那么假如有 n 个元素,其实整个矩阵序列的最小相乘次数就保存在 arr[0][n-1] 中),还需要一个数组来记录对于每个区间内求得最小相乘次数时的断点下标(这个下标为左半部分的结尾下标,为右半部分开始下标的前下标)。

然后对于整个序列存在两种情况,第一种就是我们上面提到的,当开始下标和结束下标相等时(arr[i][i] 这种情况),为单矩阵,也就是分治中的结束点,这个时候第一个数组中的此位置应该直接记录为 0。第二种情况就是当开始下标小于结束下标时,这个时候相乘矩阵的长度可能为 2个、3个、4个... n个,也就是会从 2个矩阵相乘的情况一直计算到 n个矩阵(矩阵总数为 n)相乘的情况。

我们就以两个矩阵相乘的情况进行举例,起始坐标的开始位置肯定是从头开始,而起始坐标的结尾位置应该为矩阵总数减去 2,因为我们当前为两个矩阵相乘,到最后时应该恰好结尾坐标位于整个矩阵序列的结尾处,然后起始坐标从结尾处回数两个位置。而结尾坐标的起始位置就是起始坐标的下标加上当前相乘矩阵的数目。举个例子,比如当前为两个矩阵相乘,那么 curLen = 2,那么起始坐标 start 的开始位置就是 0,结束位置就为 n - curLen,结尾下标为 end = start + curLen - 1。

然后我们在这个区间内再来求取断点 k ,断点的起始点肯定是 start,结束位置应为 end(注意这里取不到 end),然后对区间内的所有情况进行计算,即 arr[start][k] + arr[k+1][end] + 结果矩阵相乘次数(这里需要注意一下,这里是为了好理解,其实按照我们给定的输入,它的每三个数字代表两个矩阵,即比如 {35,30,25},对应的矩阵应为 A1 = 35x30,A2 = 30*25。所以完全写成递归式应该为 arr[start][k] + arr[k+1][end] + p[start]*p[k+1]*p[end+1]),比如当 start = 0,curLen = 2 时,end = 1,k = start = 0 那么这时当前断点开始处的计算结果应为 arr[0][0] + arr[1][1] + p[0]p[1]p[2],所以总的计算次序应如下。

谈使用动态规划算法解矩阵连乘问题_第2张图片

而按照例题中输入的数据,依次计算到最后时,两个数组的状态应如下(注意:第一个图中最右上角的数值就是当前矩阵序列的最小相乘次数,而第二个图中记录的时断点的下标,即例如最右上角的数字 2 ,指的是下标为 2 的矩阵,即真实含义为在第三个矩阵之后进行分割,推到原题就是 ((A0A1A2)(A3A4A5))。

谈使用动态规划算法解矩阵连乘问题_第3张图片

谈使用动态规划算法解矩阵连乘问题_第4张图片

五、代码实现

#include 

void Traceback(int start, int end, vector> breakPointRecordArray){
    
    // 当首尾坐标相等时说明已经递归到结尾
    // 即当前只剩下一个矩阵
    if(start == end){
        cout << "A" << start;
        return;
    }
    
    cout << "(";
    // 左半部分为从起始点到断点处
    Traceback(start, breakPointRecordArray[start][end], breakPointRecordArray);
    // 右半部分为从断点处到结尾点
    Traceback(breakPointRecordArray[start][end]+1, end, breakPointRecordArray);
    cout << ")";
    
}

void MatrixChain(vector dataArray, int len){
    
    // 相乘次数记录数组
    vector> frequencyRecordArray = vector>(len, vector(len));
    // 断点记录数组
    vector> breakPointRecordArray = vector>(len, vector(len));
    
    // 第一种情况即当相乘矩阵为单个数组时
    // 即当计算 start 到 end 中的最优断点时
    // start == end 的情况下相乘次数为 0
    for(int i = 0; i < len; ++i)
        frequencyRecordArray[i][i] = 0;
    
    // 第二种情况时当多个矩阵相乘时
    // curLen 为相乘矩阵的个数
    // 从最少的两个矩阵到最后的 N 个矩阵相乘
    for(int curLen = 2; curLen <= len; ++curLen){
        
        // 相乘矩阵的起始点从零开始到距离末尾为 curLen 个矩阵时结束
        // 比如一共有 6 个矩阵 那么当 curLen 为 2 时则计算到下标为 4 时截至
        for(int start = 0; start <= len - curLen; ++start){
            
            // 根据当前的起始点和矩阵相乘个数推断结尾点
            int end = start + curLen - 1;
            
            // 默认将当前求取点赋值为 0
            frequencyRecordArray[start][end] = INT_MAX;
            
            // 当起始点下标为 start 结束点为 end 时
            // 断点有 start - end 种可能性
            // 所以我们对每一种可能性都进行计算
            for(int breakPoint = start; breakPoint < end; ++breakPoint){
                
                // 当当前断点下标为 beakPoint 时
                // 前半部分为起始点到断点 后半部分为断点的下一个位置到结尾点
                // 我们依次求取两部分的最优解后相加 再加上这两部分的结果矩阵的相乘次数
                // 注意:因为我们提供的数据为线性数据 即每三个数据代表两个矩阵
                // 因此当求取两个矩阵的乘积时应使用:起始点 断点下一个点 结尾点
                // 来进行两个矩阵的乘积
                int curFrequency = frequencyRecordArray[start][breakPoint] + 
frequencyRecordArray[breakPoint+1][end] + dataArray[start]*dataArray[breakPoint+1]*dataArray[end+1];
                
                // 如果当前断点求出的相乘次数优于之前所求出的结果
                // 则分别将两个记录数组的数据进行更新
                if(curFrequency < frequencyRecordArray[start][end]){
                    frequencyRecordArray[start][end] = curFrequency;
                    breakPointRecordArray[start][end] = breakPoint;
                }
            }
        }
    }
    
    // 通过递归打印结果
    Traceback(0, len-1, breakPointRecordArray);
}

int main() {
    
    //注意给定数据的格式
    vector dataArray = {30, 35, 15, 5, 10, 20, 25};
    MatrixChain(dataArray, dataArray.size()-1);
}
# Python
import numpy as np
import sys


def print_optimal_parens(s, i, j):
    if i == j:
        print("A" + str(int(i+1)), end="")
    else:
        print("(", end="")
        print_optimal_parens(s, i, s[int(i), int(j)])
        print_optimal_parens(s, s[int(i), int(j)]+1, j)
        print(")", end="")


def matrix_chain(p, n):
    m = np.zeros((n, n))
    s = np.zeros((n, n))

    for len in range(2, n+1):
        for i in range(n-len+1):
            j = i + len - 1
            m[i, j] = sys.maxsize
            for k in range(i, j):
                q = m[i, k] + m[k+1, j] + p[i-1] * p[k] * p[j]
                if m[i, j] > q:
                    m[i, j] = q
                    s[i, j] = k

    print(m)
    print(s)

    print_optimal_parens(s, 0, n-1)


def main():
    n = 6
    p = [30, 35, 15, 5, 10, 20, 25]
    matrix_chain(p, n)


if __name__ == '__main__':
    main()

 

你可能感兴趣的:(算法——动态规划)