矩阵乘法Java实现

本文介绍几种方式实现矩阵相乘。矩阵概念一般语言没有提供,我们首先子句实现,同时也介绍一些现成库实现。

1. 概念介绍

首先通过示例介绍矩阵,首先定义第一个3x2的矩阵:
矩阵乘法Java实现_第1张图片

我们再定义第二个2x3的矩阵:
在这里插入图片描述

两个矩阵相乘,结果为3x4矩阵:
矩阵乘法Java实现_第2张图片

计算公式为:
矩阵乘法Java实现_第3张图片

第一个矩阵的列数要和第二个矩阵的行数相等,否则不能相乘。即从A矩阵的第一行开始依次和B矩阵的每列相乘,每行与列元素相乘的结果相加作为结果矩阵的一个元素。

2. 矩阵乘法实现

2.1 自己实现

首先我们自己实现矩阵相乘,为了简单我们使用二维double类型数组:

double[][] firstMatrix = {
  new double[]{1d, 5d},
  new double[]{2d, 3d},
  new double[]{1d, 7d}
};

double[][] secondMatrix = {
  new double[]{1d, 2d, 3d, 7d},
  new double[]{5d, 2d, 8d, 1d}
};

上面是两个矩阵示例,下面定义其相乘的结果矩阵:

double[][] expected = {
  new double[]{26d, 12d, 43d, 12d},
  new double[]{17d, 10d, 30d, 17d},
  new double[]{36d, 16d, 59d, 14d}
};

现在准备好了,让我们实现乘法算法。首先定义空结果数组,然后迭代其元素并存储每个元素的结果:

double[][] multiplyMatrices(double[][] firstMatrix, double[][] secondMatrix) {
    double[][] result = new double[firstMatrix.length][secondMatrix[0].length];

    for (int row = 0; row < result.length; row++) {
        for (int col = 0; col < result[row].length; col++) {
            result[row][col] = multiplyMatricesCell(firstMatrix, secondMatrix, row, col);
        }
    }

    return result;
}

最后我们实现每个元素的计算过程。即实现上面展示的公式:

double multiplyMatricesCell(double[][] firstMatrix, double[][] secondMatrix, int row, int col) {
    double cell = 0;
    for (int i = 0; i < secondMatrix.length; i++) {
        cell += firstMatrix[row][i] * secondMatrix[i][col];
    }
    return cell;
}

最后,检查结果是否与期望的一致:

double[][] actual = multiplyMatrices(firstMatrix, secondMatrix);
assertThat(actual).isEqualTo(expected);

2.2 EJML

这里介绍EJML库, 完整表述为 Efficient Java Matrix Library。该库要实现目的是在计算和内存使用方面尽可能地高效。首先引入依赖:


    org.ejml
    ejml-all
    0.38

实现逻辑与前面的一样,首先定义两个矩阵,然后检查它们乘积结果是否与期望一致。使用EJML创建矩阵,需要使用它提供的SimpleMatrix类。下面定义两个矩阵:

SimpleMatrix firstMatrix = new SimpleMatrix(
  new double[][] {
    new double[] {1d, 5d},
    new double[] {2d, 3d},
    new double[] {1d ,7d}
  }
);

SimpleMatrix secondMatrix = new SimpleMatrix(
  new double[][] {
    new double[] {1d, 2d, 3d, 7d},
    new double[] {5d, 2d, 8d, 1d}
  }
);

下面定义期望结果矩阵:

SimpleMatrix expected = new SimpleMatrix(
  new double[][] {
    new double[] {26d, 12d, 43d, 12d},
    new double[] {17d, 10d, 30d, 17d},
    new double[] {36d, 16d, 59d, 14d}
  }
);

现在准备好了,如何实现两个矩阵相乘。利用SimpleMatrix 类提供的 mult() 方法,参数为第二个参数,返回结果矩阵:

SimpleMatrix actual = firstMatrix.mult(secondMatrix);

SimpleMatrix 类并没有重写equals方法,因此不能直接验证。但它提供了isIdentical() 方法,它有两个参数:另一个矩阵和容错精度:

assertThat(actual).matches(m -> m.isIdentical(expected, 0d));

2.2 Colt

Colt库提供高性能的科学计算能力,需要加入依赖:


    colt
    colt
    1.2.0

利用Colt创建矩阵,需要使用DoubleFactory2D类,它提供三个工厂实例:dense, sparse 和 rowCompressed。每个用于创建特定类型矩阵。这里我们使用dense。调用make方法,带有二维double数组,返回DoubleMatrix2D 对象:

DoubleMatrix2D matrix = doubleFactory2D.make(/* a two dimensions double array */);

DoubleMatrix2D没有对应方式实现矩阵乘法,我们需要创建Algebra类实例,它提供了mult方法:

Algebra algebra = new Algebra();
DoubleMatrix2D actual = algebra.mult(firstMatrix, secondMatrix);

最后比较结果是否与期望一致:

assertThat(actual).isEqualTo(expected);

2.3 LA4J

LA4J表示 Linear Algebra for Java。加入依赖:


    org.la4j
    la4j
    0.6.0

LA4J 与其他库非常相似,提供了Matrix 接口及Basic2DMatrix 实现类,构造函数带二维double类型数组:

Matrix matrix = new Basic2DMatrix(/* a two dimensions double array */);

然后利用multiply方法实现乘法:

Matrix actual = firstMatrix.multiply(secondMatrix);

最后比较结果:

assertThat(actual).isEqualTo(expected);

2.4 Apache Commons

最后介绍下Apache Commons Math库,它也提供了矩阵相关运算实现,加入依赖:


    org.apache.commons
    commons-math3
    3.6.1

我们使用RealMatrix 接口及其实现类Array2DRowRealMatrix 定义矩阵。它的构造函数带有二维double类型数组:

RealMatrix matrix = new Array2DRowRealMatrix(/* a two dimensions double array */);

利用RealMatrix 提供的方法multiply实现乘法:

RealMatrix actual = firstMatrix.multiply(secondMatrix);
assertThat(actual).isEqualTo(expected);

3. 总结

本文介绍矩阵乘法实现,矩阵乘法在层次分析法中有应用。通过自我实现理解其原理,接着又介绍了几种现有库的实现,读者也可以自己通过基准测试它们性能差异。

你可能感兴趣的:(java8~9核心功能,大数据处理,工具软件,矩阵乘法)