# 操作系统、Strassen算法
#
# 不难
# 原理写得也很清楚了，确实巧妙

from threading import Thread
import time
import random

'''
:function SP_matrix: 作用是把矩阵A分解成4个n/2×n/2的子矩阵。
:function Merge_Matrix: 作用是把4个n/2×n/2的子矩阵合并为一个n×n的矩阵。
:function Add_Matrix: 作用是计算矩阵A和B的和(A + B)。
:function Sub_Matrix: 作用是计算矩阵A和B的差(A - B)。
=》算法流程
S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12

=》接着,计算7次矩阵乘法:
P1 = A11 • S1
P2 = S2 • B22
P3 = S3 • B11
P4 = A22 • S4
P5 = S5 • S6
P6 = S7 • S8
P7 = S9 • S10

=》最后,根据这7个结果就可以计算出C矩阵:
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7

'''

class BasicThread(Thread):
    """Thread that calls ``func(*args)`` and stores the return value.

    Read ``result`` after ``join()``. It is ``None`` until ``run`` has
    finished, and stays ``None`` if ``func`` raised inside the thread.
    """

    def __init__(self, func, args):
        """
        :param func: callable executed in the new thread
        :param args: tuple of positional arguments passed to ``func``
        """
        super().__init__()
        self.func, self.args = func, args
        self.result = None

    def run(self):
        target, params = self.func, self.args
        self.result = target(*params)

def _run_parallel(tasks):
    """Run every ``(func, args)`` task in its own BasicThread and return the results.

    All threads are started before any of them is joined, so the calls may
    overlap. The results are returned in the same order as ``tasks``.

    :raises RuntimeError: if a worker left ``result`` as None. BasicThread does
        this when ``func`` raises; the traceback itself is reported by the
        threading module rather than propagated here.
    """
    workers = [BasicThread(func=func, args=args) for func, args in tasks]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    results = [worker.result for worker in workers]
    if any(result is None for result in results):
        raise RuntimeError("a Strassen worker thread failed; see the thread's traceback on stderr")
    return results


def _naive_multiply(A, B):
    """Return the schoolbook O(n^3) product A·B of two square matrices."""
    columns = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in columns] for row in A]


def _pad_square(M, size):
    """Return a copy of square matrix M zero-padded to ``size`` x ``size``."""
    extra = size - len(M)
    return [list(row) + [0] * extra for row in M] + [[0] * size for _ in range(extra)]


def Strassen_Matrix(A, B, leaf_size=1):
    """Multiply square matrices A and B with Strassen's algorithm.

    Each recursion level computes S1..S10, P1..P7 and the C-quadrant terms,
    running each sub-task in its own BasicThread (23 threads per level).

    :param A: n x n matrix given as a list of row lists
    :param B: n x n matrix of the same size as A
    :param leaf_size: sub-problems with n <= leaf_size are multiplied directly
        with the O(n^3) method. The default of 1 recurses down to 1x1 blocks,
        as before. A larger value reduces the recursion depth and therefore
        the number of threads created.
    :return: the n x n product A·B as a new list of lists ([] for empty input)
    :raises ValueError: if A and B are not square matrices of the same size,
        or if leaf_size < 1
    :raises RuntimeError: if one of the worker threads fails
    """
    n = len(A)
    if n == 0:
        return []
    if len(B) != n or any(len(row) != n for row in A) or any(len(row) != n for row in B):
        raise ValueError("Strassen_Matrix needs two square matrices of the same size")
    if leaf_size < 1:
        raise ValueError("leaf_size must be at least 1")

    if n <= leaf_size:
        return _naive_multiply(A, B)

    # The recursive halving needs an even size at every level, i.e. n must be
    # a power of two. SP_matrix would silently drop the last row/column of an
    # odd-sized matrix, so zero-pad instead and trim the product back to n x n.
    size = 1 << (n - 1).bit_length()
    if size != n:
        padded = Strassen_Matrix(_pad_square(A, size), _pad_square(B, size), leaf_size)
        return [row[:n] for row in padded[:n]]

    A11, A12, A21, A22 = SP_matrix(A)
    B11, B12, B21, B22 = SP_matrix(B)

    # Step 1: the ten independent sums/differences S1..S10.
    S1, S2, S3, S4, S5, S6, S7, S8, S9, S10 = _run_parallel([
        (Sub_Matrix, (B12, B22)),
        (Add_Matrix, (A11, A12)),
        (Add_Matrix, (A21, A22)),
        (Sub_Matrix, (B21, B11)),
        (Add_Matrix, (A11, A22)),
        (Add_Matrix, (B11, B22)),
        (Sub_Matrix, (A12, A22)),
        (Add_Matrix, (B21, B22)),
        (Sub_Matrix, (A11, A21)),
        (Add_Matrix, (B11, B12)),
    ])

    # Step 2: the seven recursive half-size products P1..P7.
    P1, P2, P3, P4, P5, P6, P7 = _run_parallel([
        (Strassen_Matrix, (A11, S1, leaf_size)),
        (Strassen_Matrix, (S2, B22, leaf_size)),
        (Strassen_Matrix, (S3, B11, leaf_size)),
        (Strassen_Matrix, (A22, S4, leaf_size)),
        (Strassen_Matrix, (S5, S6, leaf_size)),
        (Strassen_Matrix, (S7, S8, leaf_size)),
        (Strassen_Matrix, (S9, S10, leaf_size)),
    ])

    # Step 3: combine the products into the quadrants of C:
    #   C11 = (P5 + P4) + (P6 - P2)    C12 = P1 + P2
    #   C21 = P3 + P4                  C22 = (P5 - P3) + (P1 - P7)
    C11A, C11B, C12, C21, C22A, C22B = _run_parallel([
        (Add_Matrix, (P5, P4)),
        (Sub_Matrix, (P6, P2)),
        (Add_Matrix, (P1, P2)),
        (Add_Matrix, (P3, P4)),
        (Sub_Matrix, (P5, P3)),
        (Sub_Matrix, (P1, P7)),
    ])

    return Merge_Matrix(Add_Matrix(C11A, C11B), C12, C21, Add_Matrix(C22A, C22B))

def SP_matrix(A):
    """Split matrix A into its four quadrants and return (A11, A12, A21, A22).

    Each quadrant has floor(rows/2) rows and floor(cols/2) columns, taken from
    the top-left, top-right, bottom-left and bottom-right of A. When a
    dimension is odd, the last row/column of A belongs to no quadrant.
    """
    row_idx = range(len(A) // 2)
    col_idx = range(len(A[0]) // 2)
    row_off, col_off = len(row_idx), len(col_idx)

    def quadrant(dr, dc):
        # Copy the half-size block whose top-left corner is at (dr, dc).
        return [[A[r + dr][c + dc] for c in col_idx] for r in row_idx]

    return (quadrant(0, 0), quadrant(0, col_off), quadrant(row_off, 0), quadrant(row_off, col_off))


def Merge_Matrix(A11, A12, A21, A22):
    """Assemble four h x h quadrants into one 2h x 2h matrix.

    h is taken from len(A11). The result is laid out as
    [[A11, A12], [A21, A22]] and returned as a new list of lists.
    """
    half = range(len(A11))
    upper = [[A11[r][c] for c in half] + [A12[r][c] for c in half] for r in half]
    lower = [[A21[r][c] for c in half] + [A22[r][c] for c in half] for r in half]
    return upper + lower

def Add_Matrix(A, B):
    """Return the element-wise sum A + B as a new matrix.

    The shape is taken from A (len(A) rows, len(A[0]) columns). B must be at
    least that large.
    """
    height, width = len(A), len(A[0])
    return [[A[r][c] + B[r][c] for c in range(width)] for r in range(height)]

def Sub_Matrix(A, B):
    """Return the element-wise difference A - B as a new matrix.

    The shape is taken from A (len(A) rows, len(A[0]) columns). B must be at
    least that large.
    """
    height, width = len(A), len(A[0])
    return [[A[r][c] - B[r][c] for c in range(width)] for r in range(height)]

if __name__ == "__main__":
    # Demo: multiply two random 64x64 matrices and report the elapsed time.
    size = 64
    A = [[random.random() for _ in range(size)] for _ in range(size)]
    B = [[random.random() for _ in range(size)] for _ in range(size)]
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # high-resolution timer intended for measuring elapsed time.
    # Only the multiplication is timed, not the random matrix generation.
    start = time.perf_counter()
    T_main = BasicThread(func=Strassen_Matrix, args=(A, B))
    T_main.start()
    T_main.join()
    end = time.perf_counter()
    print("数组A:")
    print(A)
    print("数组B:")
    print(B)
    print("结果:")
    print(T_main.result)
    print("运行时间:")
    print(end - start)

# 你可能感兴趣的：(python, numpy, 多线程)