Strassen算法也是一种基于分治思想的算法。在介绍它之前,我们先用普通的分治方法来实现矩阵乘法。
这里我通过行下标和列下标来划分矩阵,并不复制矩阵中的元素,这与《算法导论》中给出的思路是一致的。
注意:以下程序只适用于N×N的方阵,并且N必须是2的幂(如1、2、4、8……)。这是因为每次划分都把下标范围对半切分,当N不是2的幂时,划分出的子矩阵大小不再相等,算法就无法正确运行。
首先,我定义一个表示矩阵的类,Java代码如下:
package hanxl.insist.beans;
/**
 * An integer matrix that may be a view onto a rectangular region of a (possibly shared) backing
 * two-dimensional array.
 *
 * <p>The inclusive start/end row and column indices select the region of {@link #getElements()}
 * that this matrix represents. They exist so that a matrix can be partitioned into sub-matrices
 * by index ranges, without copying any elements.
 */
public class Matrix {
    /** Inclusive index of this matrix's first row in the backing array; used for partitioning. */
    private int rowStartIndex;

    /** Inclusive index of this matrix's last row in the backing array; used for partitioning. */
    private int rowEndIndex;

    /** Inclusive index of this matrix's first column in the backing array; used for partitioning. */
    private int columnStartIndex;

    /** Inclusive index of this matrix's last column in the backing array; used for partitioning. */
    private int columnEndIndex;

    /** Backing array holding the elements; may be shared with other matrices built on it. */
    private int[][] elements;

    /**
     * Creates a matrix covering the whole of the given two-dimensional array. The array is not
     * copied.
     *
     * <p>The column count is taken from the first row; an empty array yields a 0 x 0 matrix.
     *
     * @param elements the backing array
     */
    public Matrix(int[][] elements) {
        // Guard elements[0] so that an empty array (e.g. from new Matrix(0, 0)) does not throw.
        this(0, elements.length - 1,
                0, elements.length == 0 ? -1 : elements[0].length - 1,
                elements);
    }

    /**
     * Creates a matrix that views the inclusive region
     * {@code [rowStartIndex..rowEndIndex] x [columnStartIndex..columnEndIndex]} of the given array.
     * The array is not copied, so the elements this matrix represents are shared with every other
     * matrix built on the same array.
     *
     * @param rowStartIndex inclusive first row of the region
     * @param rowEndIndex inclusive last row of the region
     * @param columnStartIndex inclusive first column of the region
     * @param columnEndIndex inclusive last column of the region
     * @param elements the backing array
     */
    public Matrix(int rowStartIndex, int rowEndIndex, int columnStartIndex, int columnEndIndex,
            int[][] elements) {
        this.rowStartIndex = rowStartIndex;
        this.rowEndIndex = rowEndIndex;
        this.columnStartIndex = columnStartIndex;
        this.columnEndIndex = columnEndIndex;
        this.elements = elements;
    }

    /**
     * Creates a zero-filled matrix with the given number of rows and columns, backed by a new
     * array.
     *
     * @param row number of rows
     * @param column number of columns
     */
    public Matrix(int row, int column) {
        this(new int[row][column]);
    }

    /**
     * Returns a new matrix holding the element-wise sum {@code a + b}.
     *
     * <p>Either operand may be an index-bounded view: only the elements inside each operand's
     * start/end indices are read, relative to that operand's own start indices. The result is
     * backed by a new array whose indices start at 0.
     *
     * @param a left operand
     * @param b right operand, with the same number of rows and columns as {@code a}
     * @return a new matrix containing {@code a + b}
     * @throws IllegalArgumentException if the operands' dimensions differ
     */
    public static Matrix add(Matrix a, Matrix b) {
        int rows = a.getRows();
        int columns = a.getColumns();
        if (rows != b.getRows() || columns != b.getColumns()) {
            throw new IllegalArgumentException("cannot add a " + rows + "x" + columns
                    + " matrix to a " + b.getRows() + "x" + b.getColumns() + " matrix");
        }
        Matrix matrix = new Matrix(rows, columns);
        int[][] resultElements = matrix.getElements();
        int[][] aelements = a.getElements();
        int[][] belements = b.getElements();
        int aRow = a.getRowStartIndex();
        int aColumn = a.getColumnStartIndex();
        int bRow = b.getRowStartIndex();
        int bColumn = b.getColumnStartIndex();
        // Iterate over the logical dimensions, offsetting into each backing array by its view start.
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                resultElements[i][j] = aelements[aRow + i][aColumn + j]
                        + belements[bRow + i][bColumn + j];
            }
        }
        return matrix;
    }

    /**
     * Prints the region of the backing array selected by this matrix's four indices to standard
     * output, one row per line with elements separated by a space.
     */
    public void printMatrix() {
        int[][] e = this.getElements();
        for (int i = this.getRowStartIndex(); i <= this.getRowEndIndex(); i++) {
            for (int j = this.getColumnStartIndex(); j <= this.getColumnEndIndex(); j++) {
                System.out.print(e[i][j] + " ");
            }
            System.out.println();
        }
    }

    /** Returns the number of rows of this matrix (not of the backing array). */
    public int getRows() {
        return rowEndIndex - rowStartIndex + 1;
    }

    /** Returns the number of columns of this matrix (not of the backing array). */
    public int getColumns() {
        return columnEndIndex - columnStartIndex + 1;
    }

    public int getRowStartIndex() {
        return rowStartIndex;
    }

    public void setRowStartIndex(int rowStartIndex) {
        this.rowStartIndex = rowStartIndex;
    }

    public int getRowEndIndex() {
        return rowEndIndex;
    }

    public void setRowEndIndex(int rowEndIndex) {
        this.rowEndIndex = rowEndIndex;
    }

    public int getColumnStartIndex() {
        return columnStartIndex;
    }

    public void setColumnStartIndex(int columnStartIndex) {
        this.columnStartIndex = columnStartIndex;
    }

    public int getColumnEndIndex() {
        return columnEndIndex;
    }

    public void setColumnEndIndex(int columnEndIndex) {
        this.columnEndIndex = columnEndIndex;
    }

    /** Returns the backing array itself (not a copy); it may be larger than this matrix. */
    public int[][] getElements() {
        return elements;
    }

    public void setElements(int[][] elements) {
        this.elements = elements;
    }
}
使用普通分治方法实现矩阵乘法的代码如下:
package hanxl.insist.fourchapter;
import hanxl.insist.beans.Matrix;
/**
 * Square matrix multiplication using the plain divide-and-conquer algorithm from Introduction to
 * Algorithms (CLRS): each operand is split into four half-size blocks by index ranges (no element
 * copying), and every result block is the sum of two recursive block products.
 *
 * <p>Operands must both be n x n with n a power of two, so that every split yields four equal
 * square blocks.
 */
public class MatrixMultiply {
    public static void main(String[] args) {
        int[][] aelements = {{1, 3, 4, 5}, {7, 2, 6, 5}, {2, 4, 2, 4}, {6, 4, 8, 2}};
        Matrix a = new Matrix(aelements);
        a.printMatrix();
        System.out.println("-------------------");
        int[][] belements = {{6, 4, 5, 8}, {4, 2, 3, 2}, {1, 1, 1, 3}, {6, 9, 2, 7}};
        Matrix b = new Matrix(belements);
        b.printMatrix();
        System.out.println("-------------------");
        Matrix r = recursiveMultiply(a, b);
        r.printMatrix();
    }

    /**
     * Multiplies {@code a} by {@code b} with eight recursive half-size block products, e.g.
     * {@code C11 = A11*B11 + A12*B21}.
     *
     * <p>Operands may be index-bounded views; the result is a new matrix backed by its own array
     * with indices starting at 0.
     *
     * @param a left operand, n x n
     * @param b right operand, n x n
     * @return the n x n product {@code a * b}
     * @throws IllegalArgumentException if the operands are not both n x n with n a power of two
     */
    public static Matrix recursiveMultiply(Matrix a, Matrix b) {
        int n = requireSquarePowerOfTwo(a, b);
        if (n == 1) {
            // Base case: a 1 x 1 product is a single scalar multiplication.
            Matrix c = new Matrix(1, 1);
            c.getElements()[0][0] = a.getElements()[a.getRowStartIndex()][a.getColumnStartIndex()]
                    * b.getElements()[b.getRowStartIndex()][b.getColumnStartIndex()];
            return c;
        }
        Matrix[] amatrixs = partition(a);
        Matrix a11 = amatrixs[0];
        Matrix a12 = amatrixs[1];
        Matrix a21 = amatrixs[2];
        Matrix a22 = amatrixs[3];
        Matrix[] bmatrixs = partition(b);
        Matrix b11 = bmatrixs[0];
        Matrix b12 = bmatrixs[1];
        Matrix b21 = bmatrixs[2];
        Matrix b22 = bmatrixs[3];
        // Each result block is a fresh matrix; merge copies them into one n x n result.
        Matrix c11 = Matrix.add(recursiveMultiply(a11, b11), recursiveMultiply(a12, b21));
        Matrix c12 = Matrix.add(recursiveMultiply(a11, b12), recursiveMultiply(a12, b22));
        Matrix c21 = Matrix.add(recursiveMultiply(a21, b11), recursiveMultiply(a22, b21));
        Matrix c22 = Matrix.add(recursiveMultiply(a21, b12), recursiveMultiply(a22, b22));
        return merge(c11, c12, c21, c22);
    }

    /**
     * Checks that {@code a} and {@code b} are both n x n with n a positive power of two and
     * returns n. Only such sizes are halved into four equal square blocks by {@link #partition}.
     */
    private static int requireSquarePowerOfTwo(Matrix a, Matrix b) {
        int n = a.getRows();
        boolean square = a.getColumns() == n && b.getRows() == n && b.getColumns() == n;
        boolean powerOfTwo = n > 0 && (n & (n - 1)) == 0;
        if (!square || !powerOfTwo) {
            throw new IllegalArgumentException(
                    "expected two n x n matrices with n a power of two, got "
                            + a.getRows() + "x" + a.getColumns() + " and "
                            + b.getRows() + "x" + b.getColumns());
        }
        return n;
    }

    /**
     * Assembles four equally sized blocks into one new matrix with twice as many rows and
     * columns: {@code c11} top-left, {@code c12} top-right, {@code c21} bottom-left and
     * {@code c22} bottom-right.
     *
     * <p>Blocks may be index-bounded views; each is read relative to its own start indices.
     *
     * @param c11 top-left block
     * @param c12 top-right block
     * @param c21 bottom-left block
     * @param c22 bottom-right block
     * @return the assembled matrix, backed by a new array
     */
    public static Matrix merge(Matrix c11, Matrix c12, Matrix c21, Matrix c22) {
        int rows = c11.getRows();
        int columns = c11.getColumns();
        Matrix matrix = new Matrix(rows * 2, columns * 2);
        int[][] elements = matrix.getElements();
        copyBlock(c11, elements, 0, 0);
        copyBlock(c12, elements, 0, columns);
        copyBlock(c21, elements, rows, 0);
        copyBlock(c22, elements, rows, columns);
        return matrix;
    }

    /** Copies every element of {@code block} into {@code target} starting at the given offsets. */
    private static void copyBlock(Matrix block, int[][] target, int rowOffset, int columnOffset) {
        int[][] source = block.getElements();
        int rowStart = block.getRowStartIndex();
        int columnStart = block.getColumnStartIndex();
        int columns = block.getColumns();
        for (int i = 0; i < block.getRows(); i++) {
            System.arraycopy(source[rowStart + i], columnStart,
                    target[rowOffset + i], columnOffset, columns);
        }
    }

    /**
     * Splits a matrix into four quadrant views that share its backing array; no elements are
     * copied. The returned array holds, in order, the top-left, top-right, bottom-left and
     * bottom-right quadrants. The first half of the rows/columns (rounded up) goes to the
     * top/left quadrants.
     *
     * @param matrix the matrix to split
     * @return the four quadrant views
     */
    public static Matrix[] partition(Matrix matrix) {
        Matrix[] matrixs = new Matrix[4];
        int rowStart = matrix.getRowStartIndex();
        int rowEnd = matrix.getRowEndIndex();
        int rowMid = (rowStart + rowEnd) >>> 1; // overflow-safe midpoint
        int[][] elements = matrix.getElements();
        int columnStart = matrix.getColumnStartIndex();
        int columnEnd = matrix.getColumnEndIndex();
        int columnMid = (columnStart + columnEnd) >>> 1;
        matrixs[0] = new Matrix(rowStart, rowMid, columnStart, columnMid, elements);
        matrixs[1] = new Matrix(rowStart, rowMid, columnMid + 1, columnEnd, elements);
        matrixs[2] = new Matrix(rowMid + 1, rowEnd, columnStart, columnMid, elements);
        matrixs[3] = new Matrix(rowMid + 1, rowEnd, columnMid + 1, columnEnd, elements);
        return matrixs;
    }
}
Strassen算法的Java代码如下:
package hanxl.insist.fourchapter;
import hanxl.insist.beans.Matrix;
/**
 * Square matrix multiplication using Strassen's algorithm, as presented in Introduction to
 * Algorithms (CLRS). It uses seven recursive half-size block products (P1..P7) instead of eight,
 * built from ten block sums/differences (S1..S10).
 *
 * <p>Operands must both be n x n with n a power of two, so that every split yields four equal
 * square blocks.
 */
public class Strassen {
    public static void main(String[] args) {
        int[][] aelements = {{1, 3, 4, 5}, {7, 2, 6, 5}, {2, 4, 2, 4}, {6, 4, 8, 2}};
        Matrix a = new Matrix(aelements);
        int[][] belements = {{6, 4, 5, 8}, {4, 2, 3, 2}, {1, 1, 1, 3}, {6, 9, 2, 7}};
        Matrix b = new Matrix(belements);
        Matrix r = recursiveMultiply(a, b);
        System.out.println("----这是结果----");
        r.printMatrix();
    }

    /**
     * Multiplies {@code a} by {@code b} with Strassen's algorithm.
     *
     * <p>Operands may be index-bounded views; the result is a new matrix backed by its own array
     * with indices starting at 0.
     *
     * @param a left operand, n x n
     * @param b right operand, n x n
     * @return the n x n product {@code a * b}
     * @throws IllegalArgumentException if the operands are not both n x n with n a power of two
     */
    public static Matrix recursiveMultiply(Matrix a, Matrix b) {
        int n = requireSquarePowerOfTwo(a, b);
        if (n == 1) {
            // Base case: a 1 x 1 product is a single scalar multiplication.
            Matrix c = new Matrix(1, 1);
            c.getElements()[0][0] = a.getElements()[a.getRowStartIndex()][a.getColumnStartIndex()]
                    * b.getElements()[b.getRowStartIndex()][b.getColumnStartIndex()];
            return c;
        }
        Matrix[] amatrixs = MatrixMultiply.partition(a);
        Matrix a11 = amatrixs[0];
        Matrix a12 = amatrixs[1];
        Matrix a21 = amatrixs[2];
        Matrix a22 = amatrixs[3];
        Matrix[] bmatrixs = MatrixMultiply.partition(b);
        Matrix b11 = bmatrixs[0];
        Matrix b12 = bmatrixs[1];
        Matrix b21 = bmatrixs[2];
        Matrix b22 = bmatrixs[3];

        // S1..S10: each is a fresh matrix, not an index-bounded view.
        Matrix s1 = calculate(b12, b22, '-');
        Matrix s2 = calculate(a11, a12, '+');
        Matrix s3 = calculate(a21, a22, '+');
        Matrix s4 = calculate(b21, b11, '-');
        Matrix s5 = calculate(a11, a22, '+');
        Matrix s6 = calculate(b11, b22, '+');
        Matrix s7 = calculate(a12, a22, '-');
        Matrix s8 = calculate(b21, b22, '+');
        Matrix s9 = calculate(a11, a21, '-');
        Matrix s10 = calculate(b11, b12, '+');

        // P1..P7: the seven recursive products.
        Matrix p1 = recursiveMultiply(a11, s1);
        Matrix p2 = recursiveMultiply(s2, b22);
        Matrix p3 = recursiveMultiply(s3, b11);
        Matrix p4 = recursiveMultiply(a22, s4);
        Matrix p5 = recursiveMultiply(s5, s6);
        Matrix p6 = recursiveMultiply(s7, s8);
        Matrix p7 = recursiveMultiply(s9, s10);

        // C11 = P5 + P4 - P2 + P6, C12 = P1 + P2, C21 = P3 + P4, C22 = P5 + P1 - P3 - P7.
        Matrix c11 = calculate(calculate(p5, p4, '+'), calculate(p6, p2, '-'), '+');
        Matrix c12 = calculate(p1, p2, '+');
        Matrix c21 = calculate(p3, p4, '+');
        Matrix c22 = calculate(calculate(p5, p1, '+'), calculate(p7, p3, '+'), '-');
        return MatrixMultiply.merge(c11, c12, c21, c22);
    }

    /**
     * Checks that {@code a} and {@code b} are both n x n with n a positive power of two and
     * returns n. Only such sizes are halved into four equal square blocks by
     * {@code MatrixMultiply.partition}.
     */
    private static int requireSquarePowerOfTwo(Matrix a, Matrix b) {
        int n = a.getRows();
        boolean square = a.getColumns() == n && b.getRows() == n && b.getColumns() == n;
        boolean powerOfTwo = n > 0 && (n & (n - 1)) == 0;
        if (!square || !powerOfTwo) {
            throw new IllegalArgumentException(
                    "expected two n x n matrices with n a power of two, got "
                            + a.getRows() + "x" + a.getColumns() + " and "
                            + b.getRows() + "x" + b.getColumns());
        }
        return n;
    }

    /**
     * Returns a new matrix holding {@code left + right} or {@code left - right}, element-wise.
     *
     * <p>Both operands may be index-bounded views of a larger array: each is read relative to its
     * own start indices, while the result is backed by a new array starting at index 0. The
     * result has the dimensions of {@code left}; {@code right} must be at least as large.
     *
     * @param left left operand
     * @param right right operand
     * @param operator {@code '+'} or {@code '-'}
     * @return a new matrix with the element-wise result
     * @throws IllegalArgumentException if {@code operator} is neither {@code '+'} nor {@code '-'}
     */
    private static Matrix calculate(Matrix left, Matrix right, char operator) {
        if (operator != '+' && operator != '-') {
            throw new IllegalArgumentException("unsupported operator: " + operator);
        }
        int rows = left.getRows();
        int columns = left.getColumns();
        Matrix matrix = new Matrix(rows, columns);
        int[][] resultElements = matrix.getElements();
        int[][] leftElements = left.getElements();
        int[][] rightElements = right.getElements();
        int leftRow = left.getRowStartIndex();
        int leftColumn = left.getColumnStartIndex();
        int rightRow = right.getRowStartIndex();
        int rightColumn = right.getColumnStartIndex();
        boolean subtract = operator == '-';
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                int x = leftElements[leftRow + i][leftColumn + j];
                int y = rightElements[rightRow + i][rightColumn + j];
                resultElements[i][j] = subtract ? x - y : x + y;
            }
        }
        return matrix;
    }
}
Strassen算法中的calculate方法之所以没有复用Matrix类中的add方法,是因为calculate的参数是用下标限定的子矩阵,需要按各自的起始下标去取元素;而上面的add方法没有考虑这些下标限制,只是从下标0开始对底层二维数组逐元素相加,相当于直接操作整个二维数组。