一文搞懂pytorch中的乘法

各种乘法运算中会多次涉及到广播机制。先说下numpy中的广播机制。了解的直接跳过看第二部分即可

一 numpy广播机制Broadcast

原理:python在进行numpy算术运算采用的是element-wise方式(逐元素操作的方式),此时要求两个数据的维度必须相同。

维度不同时,会触发广播操作使其维度相同。不满足广播操作的情况下会直接报错。

先理解下维度,便于理解broadcast,数据的维度指两个方面,维度的个数和维度的大小。

如:a = np.ones(4,3)维度个数是2,第一维大小是4,第二维大小是3

广播的执行过程:

1.如果维度个数不同,则在维度较少的左边补1,使得维度的个数相同。

2.各维度的维度大小不同时,如果有维度为1的,直接将该维拉伸至维度相同

#%%
import numpy as np
a = np.arange(12).reshape((3,1,4))
b = np.array([1,2,3,4]).reshape((4,1))
print(a.shape,b.shape)
print((a+b).shape)
#输出
(3, 1, 4) (4, 1)
(3, 4, 4)

二 pytorch中的乘法

1. torch.dot()

向量点乘,得到的结果是scale标量。对应元素相乘并相加

a = torch.tensor([1,2,3])
b = torch.tensor([4,5,6])
print(torch.sum(torch.mul(a,b)))
print(torch.dot(a,b))
tensor(32)
tensor(32)

2.torch.mul和*

使用element wise相乘方式,逐元素运算,维度不相同时执行广播机制

a = torch.tensor(list(range(12))).view(4,3)
b = torch.tensor([1,2,3])
c = torch.mul(a,b)
d = a * b
print(c.shape)
print(d.shape)
tensor(32)
tensor(32)

3.torch.mm()

矩阵相乘,

#torch.mm() 矩阵相乘
a = torch.randn(4,3)
b = torch.randn(3,5)
print(torch.mm(a,b).shape)
torch.Size([4, 5])

注意:矩阵乘法的维度必须满足mn x nk,不匹配会报错。不会进行广播机制。

a = torch.randn(4,3)
b = torch.randn(3,)
print(torch.mm(a,b).shape)

 维度不匹配会报错

---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

 in 
      2 a = torch.tensor(np.arange(12)).view(4,3)
      3 b = torch.tensor(np.arange(3)).view(3,)
----> 4 print(torch.mm(a,b).shape)
      5 
      6 

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

4 torch.bmm()

批量矩阵相乘。 

注意 bmm中两个矩阵的后两维度必须满足矩阵相乘的条件,mn nk

torch.mm(bmn, bnk) = bnk,不会进行广播机制,想使用广播机制可以用torch.matmul()

a = torch.randn(6,3,4)
b = torch.randn(6,4,5)
print(torch.matmul(a,b).shape)
torch.Size([6, 3, 5])

5 torch.matmul()

#两个tensor的矩阵乘法,具体操作取决于两个tensor的shape,按两个矩阵维度的不同可分为以下五种

1)如果两个矩阵都是1维,则执行向量点乘dot操作。

2)如果两个矩阵都是2维,这执行矩阵相乘mm操作。

3)第一个矩阵是1维,第二个是2维,行向量乘以矩阵。(线性代数矩阵和向量相乘),向量x矩阵相当于矩阵行向量的线性组合

4)第一个矩阵是2维,第二个是1维,矩阵乘以列向量。矩阵x向量,相当于矩阵列向量的线性组合

5)如果两个都至少是1维,并且至少一个维度大于2。会执行batch矩阵相乘torch.bmm。后两维进行矩阵相乘mm。

相当于将每个矩阵看做一个元素,然后逐元素进行矩阵乘法。例如a.shape=[j,k,m,n] b.shape=[k,n,k]。将每个矩阵看做一个element时,a.shape=[j,k],b.shape=[k]

然后按element-wise的方式进行mm矩阵乘法。element-wise方式要求维度个数和大小必须相同。不相同执行广播机制。

测试:

#情形1
a = torch.randn(3)
b = torch.randn(3)
torch.matmul(a,b).shape
#torch.Size([])

#情形2
a = torch.randn(3,4)
b = torch.randn(4,5)
torch.matmul(a,b).shape
#torch.Size([3, 5])

#情形3
a = torch.randn(3)
b = torch.randn(3,5)
torch.matmul(a,b).shape
#torch.Size([5])

#情形4
a = torch.randn(3,4)
b = torch.randn(4)
torch.matmul(a,b).shape
#torch.Size([3])

#情形5 
#5.1 batch 矩阵相乘
a = torch.randn(5,6,3,4)
b = torch.randn(5,6,4,5)
torch.matmul(a,b).shape
#torch.Size([5, 6, 3, 5])

#5.2 缺少的非-矩阵维度(不参与矩阵运算的维度)会进行广播
a = torch.randn(5,6,3,4)
b = torch.randn(6,4,5)
torch.matmul(a,b).shape
#torch.Size([5, 6, 3, 5])

#5.3 batch 行向量x矩阵 
a = torch.randn(3)
b = torch.randn(5,6,3,4)
torch.matmul(a,b).shape
#torch.Size([5, 6, 4])

#5.4 batch 矩阵x列向量
a = torch.randn(5,6,7,8)
b = torch.randn(8)
torch.matmul(a,b).shape
#torch.Size([5, 6, 7])

 

你可能感兴趣的:(pytorch,pytorch)