动态规划实例:矩阵连乘

矩阵连乘

      • 问题描述
      • 算法设计思路
      • 实现代码
      • 完整代码实现
      • 算法时间复杂度

问题描述

       采用标准的矩阵乘法来计算M1、M2和M3三个矩阵的乘积M1M2M3,设这三个矩阵的维数分别是2 × 10、10 × 2和2 × 10。如果先把M1和M2相乘,然后把结果和M3相乘,那么要进行2 × 10 × 2 + 2 × 2 × 10 = 80次乘法;如果代之用M2和M3相乘的结果去乘M1,那么数量乘法的次数为10 × 2 × 10 + 2 × 10 × 10 = 400。显然,执行乘法M1(M2M3)耗费的时间是执行乘法(M1M2)M3的5倍。一般来说,n个矩阵链M1M2……Mn乘法的耗费,取决于n – 1个乘法执行的顺序。请设计一个动态规划算法,使得计算矩阵链乘法时需要的数量乘法次数达到最小。

算法设计思路

       因为子问题有重复的,而且最终问题可以由几个子问题组成,所以可以采用动态规划,自底向上的备忘录法。矩阵A1:2 x 3 ,矩阵A2:3 x 4,则乘积次数:2 x 3 x 4。
递推式:
当ii-1pkpj)
当i=j,即矩阵个数为1时, m[i][j]=0。
其中,k为断开位置,i为前边界,j为后边界 pi-1为第一个矩阵的行,pk断开位置的列(或下一个的行),pj为最后一个矩阵的列。

实现代码

private static void martrixChain(int[] p,int n,int[][] m,int[][] s) {
		for(int i = 1;i <= n;i++)//单个矩阵赋初值
			m[i][i] = 0;
		for(int r = 2;r <= n;r++) {//子问题规模
			for(int i = 1;i <= n-r+1;i++) {//子问题的前边界
				int j = i+r-1;			   //后边界
				s[i][j] = i;			   //存储断开位置
				m[i][j] = m[i+1][j] + p[i-1]*p[i]*p[j];//计算(A1)*(A2A3A4...)的乘积
				for(int k = i+1;k < j;k++) {		   //由于前面断开位置i已计算,则断开位置k从i+1开始
					int t = m[i][k] + m[k+1][j] + p[i-1]*p[k]*p[j];//计算其他矩阵乘积
					if(t < m[i][j]) {								//若<之前计算的值,则覆盖,并记录断开位置k
						m[i][j] = t;							
						s[i][j] = k;
					}
				}
			}
		}
	}

完整代码实现

import java.util.Scanner;

public class Main {
	public static void main(String[] args) {
		Scanner in = new Scanner(System.in);
		int n = in.nextInt();
		int[] p = new int[n+1];
		for(int i = 0;i <= n;i++) {
			p[i] = in.nextInt();
		}
		int[][] m = new int[n+1][n+1];
		int[][] s = new int[n+1][n+1];
		martrixChain(p, n, m, s);
		traceBack(1, n, s);
		System.out.println("乘法次数:"+m[1][n]);
	}

	private static void martrixChain(int[] p,int n,int[][] m,int[][] s) {
		for(int i = 1;i <= n;i++)//单个矩阵赋初值
			m[i][i] = 0;
		for(int r = 2;r <= n;r++) {//子问题规模
			for(int i = 1;i <= n-r+1;i++) {//子问题的前边界
				int j = i+r-1;			   //后边界
				s[i][j] = i;			   //存储断开位置
				m[i][j] = m[i+1][j] + p[i-1]*p[i]*p[j];//计算(A1)*(A2A3A4...)的乘积
				for(int k = i+1;k < j;k++) {		   //由于前面断开位置i已计算,则断开位置k从i+1开始
					int t = m[i][k] + m[k+1][j] + p[i-1]*p[k]*p[j];//计算其他矩阵乘积
					if(t < m[i][j]) {								//若<之前计算的值,则覆盖,并记录断开位置k
						m[i][j] = t;							
						s[i][j] = k;
					}
				}
			}
		}
	}
	//构造最优解
	static void traceBack(int i, int j, int[][] s) {  
	    if(i == j) return; //当只有一个矩阵时,不需要相乘
	    traceBack(i, s[i][j], s); //分为两边,左边:i~s[i][j]的断开位置
	    traceBack(s[i][j]+1,  j, s);//		    右边:s[i][j]+1~j
	    System.out.print("Multiply A"+i+","+s[i][j]);
	    System.out.println(" and A"+(s[i][j]+1)+","+j);
	}
}

算法时间复杂度

       matrixChain的主要计算量取决于算法中对r,i和k的3重循环:循环体内的计算量为O(1),而3重循环的总次数为O(n3)

你可能感兴趣的:(算法设计)