torch.mul() 、 torch.mm() 及torch.matmul()的区别
举例
import torch
a = torch.rand(3, 4)
b = torch.rand(3, 4)
c = torch.rand(4, 5)
print(torch.mul(a, b).size()) # 返回 1*2 的tensor
print(torch.mm(a, c).size()) # 返回 1*3 的tensor
print(torch.mul(a, c).size()) # 由于a、b维度不同,报错
输出
torch.Size([3, 4])
torch.Size([3, 5])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-27-aea68cb5481f> in <module>
7 print(torch.mul(a, b).size()) # 返回 1*2 的tensor
8 print(torch.mm(a, c).size()) # 返回 1*3 的tensor
----> 9 print(torch.mul(a, c).size()) # 由于a、b维度不同,报错
RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 1
参考:https://pytorch.org/docs/stable/torch.html#torch.bmm
torch.bmm(input, mat2, out=None) → Tensor
torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的A*B。
参数:
input,mat2:两个要进行相乘的tensor结构,两者必须是3D维度的,每个维度中的大小是相同的。
output:输出结果
并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘。
举例
import torch
x = torch.rand(2,4,5)
y = torch.rand(2,5,7)
print(torch.bmm(x,y).size())
输出
torch.Size([2, 4, 7])
torch.matmul()也是一种类似于矩阵相乘操作的tensor联乘操作。但是它可以利用python 中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。
参数:
input,other:两个要进行操作的tensor结构
output:结果
一些规则约定:
输入
import torch
x = torch.rand(5) #1D
x1 = x.view(1,-1)
y = torch.rand(5,3) #2D
print(x1.size())
print(x.size())
print(y.size())
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
print(torch.matmul(x1,y),'\n',torch.matmul(x,y).size())
输出
torch.Size([1, 5])
torch.Size([5])
torch.Size([5, 3])
tensor([1.5374, 1.3291, 1.8289])
torch.Size([3])
tensor([[1.5374, 1.3291, 1.8289]])
torch.Size([3])
举例
import torch
x = torch.rand(3) #1D
x1 = x.view(-1,1)
y = torch.rand(5,3) #2D
print(x1.size())
print(x.size())
print(y.size())
print(torch.matmul(y,x),'\n',torch.matmul(y,x).size())
print(torch.matmul(y,x1),'\n',torch.matmul(y,x1).size())
输出
torch.Size([3, 1])
torch.Size([3])
torch.Size([5, 3])
tensor([0.6472, 0.7025, 0.2358, 0.2873, 0.5696])
torch.Size([5])
tensor([[0.6472],
[0.7025],
[0.2358],
[0.2873],
[0.5696]])
torch.Size([5, 1])
(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)
言而总之,总而言之:matmul()根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。
参考文献:https://www.jianshu.com/p/e277f7fc67b3