在深度学习中经常会遇到不同维度的矩阵相乘的情况,本文会通过一些例子来展示不同维度矩阵乘法的过程。
总体原则:在高维矩阵中取与低维矩阵相同维度的子矩阵来与低维矩阵相乘,结果再按子矩阵的排列顺序还原为高维矩阵。相乘结果的维度与原来的高维矩阵一致。
具体来说,当一方为一维矩阵时,另一方取其最后一维子矩阵来做乘法;当两方都是大于等于2维的矩阵时,取各自的最后两维构成的子矩阵来做乘法,其他维度体现结果的拼接信息,不参与运算(为batch训练提供了便利,batch中各样本的顺序在矩阵运算前后保持一致)。
实例:下面我们从低维到高维,依次演示不同维度矩阵相乘的结果。
二维矩阵依次取出一维的行向量与一维矩阵做内积
#二维乘一维
import numpy as np
a = np.linspace(1,4,4).reshape(2,2)
b = np.array([1,1])
c = np.matmul(a,b)
print('a:\n',a)
print('b:\n',b)
print('ab:\n',c)
三维矩阵包含两个二维矩阵,分别将这两个二维矩阵与一维矩阵相乘(乘积为一维),结果按原来的顺序拼接起来,构成一个二维矩阵
#三维乘一维
import numpy as np
a = np.linspace(1,8,8).reshape(2,2,2)
b = np.array([1,1])
c = np.matmul(a,b)
print('a:\n',a)
print('b:\n',b)
print('ab:\n',c)
最常见的矩阵相乘形式
#二维乘二维
import numpy as np
a = np.linspace(1,4,4).reshape(2,2)
b = np.ones((2,2))
c = np.matmul(a,b)
print('a:\n',a)
print('b:\n',b)
print('ab:\n',c)
将三维矩阵中的后两维组成的二维子矩阵分别与二维矩阵相乘(二维),结果再按原顺序拼接起来(三维)
#相当于三维矩阵里的二维分量分别与二维矩阵相乘,再拼接起来
import numpy as np
a=np.linspace(1,8,8).reshape(2,2,2)
# print(a)
b = np.array([[1,0],[0,1]]) #单位矩阵
c = np.matmul(a,b)
print('a:\n',a)
print('b:\n',b)
print('ab:\n',c)
两个三维矩阵中对应位置的二维子矩阵分别相乘,结果按第0维分量更多的那个矩阵的结构拼接。
注意:,并不是任意两个三维矩阵都能相乘,其必须满足两个条件:
1
:两个矩阵的后两个维度构成的二维矩阵之间必须满足二维矩阵相乘的条件,即第一个矩阵的列数等于第二个矩阵的行数
2
:两个矩阵的第0维分量数必须相等(每个分量对应相乘) 或 有一方为1(broadcast-广播机制) ----反例见下方第3种情况
#三维乘三维 (2,2,2)*(2,2,2)
#计算时都是二维乘二维,第三维度反映二维矩阵的拼接信息;对应位置二维矩阵相乘
import numpy as np
a=np.linspace(1,8,8).reshape(2,2,2)
# print(a)
e = np.array([[[1,0],[0,1]]])
f = np.array([[[0,1],[1,0]]])
b = np.vstack((e,f))
c = np.matmul(a,b)
print('a:\n',a)
print('b:\n',b)
print('ab:\n',c)
#三维乘三维 (2,2,2)*(1,2,2)
#广播机制(broadcast)
import numpy as np
a=np.linspace(1,8,8).reshape(2,2,2)
# print(a)
b = np.array([[[0,1],[1,0]]])
c = np.matmul(a,b)
print('a:\n',a)
print('b:\n',b)
print('ab:\n',c)
第0维分量数不满足条件2,不能相乘
# 三维乘三维--不同形状:(4,2,2)*(2,2,2)
import numpy as np
a=np.linspace(1,16,16).reshape(4,2,2)
# print(a)
e = np.array([[[1,0],[0,1]]])
f = np.array([[[0,1],[1,0]]])
b = np.vstack((e,f))
c = np.matmul(a,b)
print('a:\n',a)
print('b:\n',b)
print('c:\n',c)
与三维乘三维类似,可乘条件2改为:除最后两维外,每一维的分量数必须对应相等(每个分量对应相乘) 或 有一方为1(broadcast-广播机制)
#各维度的分量相互对应,最终仍是计算二维乘二维
import numpy as np
a=np.linspace(1,16,16).reshape(2,2,2,2)
b = np.ones((2,2,2,2)) #全1矩阵
c = np.matmul(a,b)
print('a:\n',a)
print('b:\n',b)
print('ab:\n',c)
知乎:多维矩阵相乘的可视化