各种乘法运算中会多次涉及到广播机制。先说下numpy中的广播机制。了解的直接跳过看第二部分即可
原理:python在进行numpy算术运算采用的是element-wise方式(逐元素操作的方式),此时要求两个数据的维度必须相同。
维度不同时,会触发广播操作使其维度相同。不满足广播操作的情况下会直接报错。
先理解下维度,便于理解broadcast,数据的维度指两个方面,维度的个数和维度的大小。
如:a = np.ones(4,3)维度个数是2,第一维大小是4,第二维大小是3
广播的执行过程:
1.如果维度个数不同,则在维度较少的左边补1,使得维度的个数相同。
2.各维度的维度大小不同时,如果有维度为1的,直接将该维拉伸至维度相同
#%%
import numpy as np
a = np.arange(12).reshape((3,1,4))
b = np.array([1,2,3,4]).reshape((4,1))
print(a.shape,b.shape)
print((a+b).shape)
#输出
(3, 1, 4) (4, 1)
(3, 4, 4)
1. torch.dot()
向量点乘,得到的结果是scale标量。对应元素相乘并相加
a = torch.tensor([1,2,3])
b = torch.tensor([4,5,6])
print(torch.sum(torch.mul(a,b)))
print(torch.dot(a,b))
tensor(32)
tensor(32)
2.torch.mul和*
使用element wise相乘方式,逐元素运算,维度不相同时执行广播机制
a = torch.tensor(list(range(12))).view(4,3)
b = torch.tensor([1,2,3])
c = torch.mul(a,b)
d = a * b
print(c.shape)
print(d.shape)
tensor(32)
tensor(32)
3.torch.mm()
矩阵相乘,
#torch.mm() 矩阵相乘
a = torch.randn(4,3)
b = torch.randn(3,5)
print(torch.mm(a,b).shape)
torch.Size([4, 5])
注意:矩阵乘法的维度必须满足mn x nk,不匹配会报错。不会进行广播机制。
a = torch.randn(4,3)
b = torch.randn(3,)
print(torch.mm(a,b).shape)
维度不匹配会报错
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
in
2 a = torch.tensor(np.arange(12)).view(4,3)
3 b = torch.tensor(np.arange(3)).view(3,)
----> 4 print(torch.mm(a,b).shape)
5
6
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
4 torch.bmm()
批量矩阵相乘。
注意 bmm中两个矩阵的后两维度必须满足矩阵相乘的条件,mn nk
torch.mm(bmn, bnk) = bnk,不会进行广播机制,想使用广播机制可以用torch.matmul()
a = torch.randn(6,3,4)
b = torch.randn(6,4,5)
print(torch.matmul(a,b).shape)
torch.Size([6, 3, 5])
5 torch.matmul()
#两个tensor的矩阵乘法,具体操作取决于两个tensor的shape,按两个矩阵维度的不同可分为以下五种
1)如果两个矩阵都是1维,则执行向量点乘dot操作。
2)如果两个矩阵都是2维,这执行矩阵相乘mm操作。
3)第一个矩阵是1维,第二个是2维,行向量乘以矩阵。(线性代数矩阵和向量相乘),向量x矩阵相当于矩阵行向量的线性组合
4)第一个矩阵是2维,第二个是1维,矩阵乘以列向量。矩阵x向量,相当于矩阵列向量的线性组合
5)如果两个都至少是1维,并且至少一个维度大于2。会执行batch矩阵相乘torch.bmm。后两维进行矩阵相乘mm。
相当于将每个矩阵看做一个元素,然后逐元素进行矩阵乘法。例如a.shape=[j,k,m,n] b.shape=[k,n,k]。将每个矩阵看做一个element时,a.shape=[j,k],b.shape=[k]
然后按element-wise的方式进行mm矩阵乘法。element-wise方式要求维度个数和大小必须相同。不相同执行广播机制。
测试:
#情形1
a = torch.randn(3)
b = torch.randn(3)
torch.matmul(a,b).shape
#torch.Size([])
#情形2
a = torch.randn(3,4)
b = torch.randn(4,5)
torch.matmul(a,b).shape
#torch.Size([3, 5])
#情形3
a = torch.randn(3)
b = torch.randn(3,5)
torch.matmul(a,b).shape
#torch.Size([5])
#情形4
a = torch.randn(3,4)
b = torch.randn(4)
torch.matmul(a,b).shape
#torch.Size([3])
#情形5
#5.1 batch 矩阵相乘
a = torch.randn(5,6,3,4)
b = torch.randn(5,6,4,5)
torch.matmul(a,b).shape
#torch.Size([5, 6, 3, 5])
#5.2 缺少的非-矩阵维度(不参与矩阵运算的维度)会进行广播
a = torch.randn(5,6,3,4)
b = torch.randn(6,4,5)
torch.matmul(a,b).shape
#torch.Size([5, 6, 3, 5])
#5.3 batch 行向量x矩阵
a = torch.randn(3)
b = torch.randn(5,6,3,4)
torch.matmul(a,b).shape
#torch.Size([5, 6, 4])
#5.4 batch 矩阵x列向量
a = torch.randn(5,6,7,8)
b = torch.randn(8)
torch.matmul(a,b).shape
#torch.Size([5, 6, 7])