分享矩阵乘法单线程与多线程的Java实现与效率对比,请教Strassen算法

分享矩阵乘法单线程与多线程的Java实现与效率对比,请教Strassen算法

矩阵乘法的多线程实现:

/**    
* @Title: MultiThreadMatrix.java 
* @Package matrix 
* @Description: 多线程计算矩阵乘法 
@author Aloong 
* @date 2010-10-28 下午09:45:56 
@version V1.0 
*/
 

package  matrix;

import  java.util.Date;


public   class  MultiThreadMatrix
{
    
    
static int[][] matrix1;
    
static int[][] matrix2;
    
static int[][] matrix3;
    
static int m,n,k;
    
static int index;
    
static int threadCount;
    
static long startTime;
    
    
public static void main(String[] args) throws InterruptedException
    
{
        
//矩阵a高度m=100宽度k=80,矩阵b高度k=80宽度n=50 ==> 矩阵c高度m=100宽度n=50
        m = 1024;
        n 
= 1024;
        k 
= 1024;
        matrix1 
= new int[m][k];
        matrix2 
= new int[k][n];
        matrix3 
= new int[m][n];
        
        
//随机初始化矩阵a,b
        fillRandom(matrix1);
        fillRandom(matrix2);
        startTime 
= new Date().getTime();
        
        
//输出a,b
//        printMatrix(matrix1);
//        printMatrix(matrix2);
        
        
//创建线程,数量 <= 4
        for(int i=0; i<4; i++)
        
{
            
if(index < m)
            
{
                Thread t 
= new Thread(new MyThread());
                t.start();
            }
else 
            
{
                
break;
            }

        }

        
        
//等待结束后输出
        while(threadCount!=0)
        
{
            Thread.sleep(
20);
        }

        
//        printMatrix(matrix3);
        long finishTime = new Date().getTime();
        System.out.println(
"计算完成,用时"+(finishTime-startTime)+"毫秒");
    }

    
    
static void printMatrix(int[][] x)
    
{
        
for (int i=0; i<x.length; i++)
        
{
            
for(int j=0; j<x[i].length; j++)
            
{
                System.out.print(x[i][j]
+" ");
            }

            System.out.println(
"");
        }

        System.out.println(
"");
    }

    
    
static void fillRandom(int[][] x)
    
{
        
for (int i=0; i<x.length; i++)
        
{
            
for(int j=0; j<x[i].length; j++)
            
{
                
//每个元素设置为0到99的随机自然数
                x[i][j] = (int) (Math.random() * 100);
            }

        }

    }


    
synchronized static int getTask()
    
{
        
if(index < m)
        
{
            
return index++;
        }

        
return -1;
    }


}


class  MyThread  implements  Runnable
{
    
int task;
    @Override
    
public void run()
    
{
        MultiThreadMatrix.threadCount
++;
        
while( (task = MultiThreadMatrix.getTask()) != -1 )
        
{
            System.out.println(
"进程: "+Thread.currentThread().getName()+"\t开始计算第 "+(task+1)+"");
            
for(int i=0; i<MultiThreadMatrix.n; i++)
            
{
                
for(int j=0; j<MultiThreadMatrix.k; j++)
                
{
                    MultiThreadMatrix.matrix3[task][i] 
+= MultiThreadMatrix.matrix1[task][j] * MultiThreadMatrix.matrix2[j][i];
                }

            }

        }

        MultiThreadMatrix.threadCount
--;
    }

}



单线程:

/**    
* @Title: SingleThreadMatrix.java 
* @Package matrix 
* @Description: 单线程计算矩阵乘法 
@author Aloong 
* @date 2010-10-28 下午11:33:18 
@version V1.0 
*/
 

package  matrix;

import  java.util.Date;


public   class  SingleThreadMatrix
{
    
static int[][] matrix1;
    
static int[][] matrix2;
    
static int[][] matrix3;
    
static int m,n,k;
    
static long startTime;
    
    
public static void main(String[] args)
    
{
        m 
= 1024;
        n 
= 1024;
        k 
= 1024;
        matrix1 
= new int[m][k];
        matrix2 
= new int[k][n];
        matrix3 
= new int[m][n];
        
        fillRandom(matrix1);
        fillRandom(matrix2);
        startTime 
= new Date().getTime();
        
        
//输出a,b
//        printMatrix(matrix1);
//        printMatrix(matrix2);
        

        
        
for(int task=0; task<m; task++)
        
{
            System.out.println(
"进程: "+Thread.currentThread().getName()+"\t开始计算第 "+(task+1)+"");
            
for(int i=0; i<n; i++)
            
{
                
for(int j=0; j<k; j++)
                
{
                    matrix3[task][i] 
+= matrix1[task][j] * matrix2[j][i];
                }

            }

        }

        
//        printMatrix(matrix3);
        long finishTime = new Date().getTime();
        System.out.println(
"计算完成,用时"+(finishTime-startTime)+"毫秒");
    }


    
static void fillRandom(int[][] x)
    
{
        
for (int i=0; i<x.length; i++)
        
{
            
for(int j=0; j<x[i].length; j++)
            
{
                
//每个元素设置为0到99的随机自然数
                x[i][j] = (int) (Math.random() * 100);
            }

        }

    }

}


修改m,n,k的值可以修改相乘矩阵的阶数.

结果对比,计算1024阶矩阵的时候多线程用时约4.8秒,单线程用时16秒,
单线程占用内存21M,多线程占用16M.
本机是4核CPU,单线程的时候只有25%的CPU占用,使用4个子线程可以达到接近100%的CPU使用率.


另外请教一个问题,是矩阵乘法的Strassen算法
下面这个是来自网上的一段代码,在我自己的电脑上,只要超过12阶就会内存溢出
不解是什么原因,设置jvm的内存不管多大也会崩溃在12阶
请高手帮忙解答....


package  matrix;

import  java.io. * ;
import  java.util. * ;

class  Matrix  // 定义矩阵结构    
{
    
public int[][] m = new int[32][32];
}


public   class  StrassenMatrix2
{
    
public int IfIsEven(int n)//判断输入矩阵阶数是否为2^k    
    {
        
int a = 0, temp = n;
        
while (temp % 2 == 0)
        
{
            
if (temp % 2 == 0)
                temp 
/= 2;
            
else
                a 
= 1;
        }

        
if (temp == 1)
            a 
= 0;
        
return a;
    }


    
public void Divide(Matrix d, Matrix d11, Matrix d12, Matrix d21, Matrix d22, int n)//分解矩阵    
    {
        
int i, j;
        
for (i = 1; i <= n; i++)
            
for (j = 1; j <= n; j++)
            
{
                d11.m[i][j] 
= d.m[i][j];
                d12.m[i][j] 
= d.m[i][j + n];
                d21.m[i][j] 
= d.m[i + n][j];
                d22.m[i][j] 
= d.m[i + n][j + n];
            }

    }


    
public Matrix Merge(Matrix a11, Matrix a12, Matrix a21, Matrix a22, int n)//合并矩阵    
    {
        
int i, j;
        Matrix a 
= new Matrix();
        
for (i = 1; i <= n; i++)
            
for (j = 1; j <= n; j++)
            
{
                a.m[i][j] 
= a11.m[i][j];
                a.m[i][j 
+ n] = a12.m[i][j];
                a.m[i 
+ n][j] = a21.m[i][j];
                a.m[i 
+ n][j + n] = a22.m[i][j];
            }

        
return a;
    }


    
public Matrix TwoMatrixMultiply(Matrix x, Matrix y) //阶数为2的矩阵乘法    
    {
        
int m1, m2, m3, m4, m5, m6, m7;
        Matrix z 
= new Matrix();

        m1 
= (y.m[1][2- y.m[2][2]) * x.m[1][1];
        m2 
= y.m[2][2* (x.m[1][1+ x.m[1][2]);
        m3 
= (x.m[2][1+ x.m[2][2]) * y.m[1][1];
        m4 
= x.m[2][2* (y.m[2][1- y.m[1][1]);
        m5 
= (x.m[1][1+ x.m[2][2]) * (y.m[1][1+ y.m[2][2]);
        m6 
= (x.m[1][2- x.m[2][2]) * (y.m[2][1+ y.m[2][2]);
        m7 
= (x.m[1][1- x.m[2][1]) * (y.m[1][1+ y.m[1][2]);
        z.m[
1][1= m5 + m4 - m2 + m6;
        z.m[
1][2= m1 + m2;
        z.m[
2][1= m3 + m4;
        z.m[
2][2= m5 + m1 - m3 - m7;
        
return z;
    }


    
public Matrix MatrixPlus(Matrix f, Matrix g, int n) //矩阵加法    
    {
        
int i, j;
        Matrix h 
= new Matrix();
        
for (i = 1; i <= n; i++)
            
for (j = 1; j <= n; j++)
                h.m[i][j] 
= f.m[i][j] + g.m[i][j];
        
return h;
    }


    
public Matrix MatrixMinus(Matrix f, Matrix g, int n) //矩阵减法方法    
    {
        
int i, j;
        Matrix h 
= new Matrix();
        
for (i = 1; i <= n; i++)
            
for (j = 1; j <= n; j++)
                h.m[i][j] 
= f.m[i][j] - g.m[i][j];
        
return h;
    }


    
public Matrix MatrixMultiply(Matrix a, Matrix b, int n) //矩阵乘法方法    
    {
        
int k;
        Matrix a11, a12, a21, a22;
        a11 
= new Matrix();
        a12 
= new Matrix();
        a21 
= new Matrix();
        a22 
= new Matrix();
        Matrix b11, b12, b21, b22;
        b11 
= new Matrix();
        b12 
= new Matrix();
        b21 
= new Matrix();
        b22 
= new Matrix();
        Matrix c11, c12, c21, c22, c;
        c11 
= new Matrix();
        c12 
= new Matrix();
        c21 
= new Matrix();
        c22 
= new Matrix();
        c 
= new Matrix();
        Matrix m1, m2, m3, m4, m5, m6, m7;
        k 
= n;
        
if (k == 2)
        
{
            c 
= TwoMatrixMultiply(a, b);
            
return c;
        }
 else
        
{
            k 
= n / 2;
            Divide(a, a11, a12, a21, a22, k); 
//拆分A、B、C矩阵    
            Divide(b, b11, b12, b21, b22, k);
            Divide(c, c11, c12, c21, c22, k);

            m1 
= MatrixMultiply(a11, MatrixMinus(b12, b22, k), k);
            m2 
= MatrixMultiply(MatrixPlus(a11, a12, k), b22, k);
            m3 
= MatrixMultiply(MatrixPlus(a21, a22, k), b11, k);
            m4 
= MatrixMultiply(a22, MatrixMinus(b21, b11, k), k);
            m5 
= MatrixMultiply(MatrixPlus(a11, a22, k),
                    MatrixPlus(b11, b22, k), k);
            m6 
= MatrixMultiply(MatrixMinus(a12, a22, k),
                    MatrixPlus(b21, b22, k), k);
            m7 
= MatrixMultiply(MatrixMinus(a11, a21, k),
                    MatrixPlus(b11, b12, k), k);
            c11 
= MatrixPlus(MatrixMinus(MatrixPlus(m5, m4, k), m2, k), m6, k);
            c12 
= MatrixPlus(m1, m2, k);
            c21 
= MatrixPlus(m3, m4, k);
            c22 
= MatrixMinus(MatrixMinus(MatrixPlus(m5, m1, k), m3, k), m7, k);

            c 
= Merge(c11, c12, c21, c22, k); //合并C矩阵    
            return c;
        }

    }


    
public Matrix GetMatrix(Matrix X, int n)
    
{
        
int i, j;
        X 
= new Matrix();
        
for (i = 1; i <= n; i++)
            
for (j = 1; j <= n; j++)
                X.m[i][j] 
= (int) (Math.random() * 10);
        
for (i = 1; i <= n; i++)
        
{
            
for (j = 1; j <= n; j++)
                System.out.print(X.m[i][j] 
+ " ");
            System.out.println();
        }

        
return X;
    }


    
public Matrix UsualMatrixMultiply(Matrix A, Matrix B, Matrix C, int n)
    
{
        
int i, j, t, k;
        
for (i = 1; i <= n; i++)
            
for (j = 1; j <= n; j++)
            
{
                
for (k = 1, t = 0; k <= n; k++)
                    t 
+= A.m[i][k] * B.m[k][j];
                C.m[i][j] 
= t;
            }

        
return C;
    }


    
public static void main(String[] args) throws IOException
    
{
        StrassenMatrix2 instance 
= new StrassenMatrix2();
        
int i, j, n;
//        Matrix A, B, C, D;
        Matrix A, B, C;
        A 
= new Matrix();
        B 
= new Matrix();
        C 
= new Matrix();
//        D = new matrix();
        Scanner in = new Scanner(System.in);
        System.out.print(
"输入矩阵的阶数: ");
        n 
= in.nextInt();
        
if (instance.IfIsEven(n) == 0)
        
{
            System.out.println(
"矩阵A:");
            A 
= instance.GetMatrix(A, n);
            System.out.println(
"矩阵B:");
            B 
= instance.GetMatrix(B, n);
            
if (n == 1)
                C.m[
1][1= A.m[1][1* B.m[1][1]; //矩阵阶数为1时的特殊处理     
            else
            
{
                
long startTime = new Date().getTime();
                C 
= instance.MatrixMultiply(A, B, n);
                
long finishTime = new Date().getTime();
                System.out.println(
"计算完成,用时"+(finishTime-startTime)+"毫秒");
            }

            System.out.println(
"Strassen矩阵C为:");
            
for (i = 1; i <= n; i++)
            
{
                
for (j = 1; j <= n; j++)
                    System.out.print(C.m[i][j] 
+ " ");
                System.out.println();
            }

            
/*            D = instance.UsualMatrixMultiply(A, B, D, n);
            System.out.println("普通乘法矩阵D为:");
            for (i = 1; i <= n; i++)
            {
                for (j = 1; j <= n; j++)
                    System.out.print(D.m[i][j] + " ");
                System.out.println();
            }
*/

        }
 else
            System.out.println(
"输入的阶数不是2的N次方");
    }

}
 

你可能感兴趣的:(分享矩阵乘法单线程与多线程的Java实现与效率对比,请教Strassen算法)