提示:文章只是作为自己学习中的记录、总结,不喜勿喷;同时,本人水平有限,如有错误,敬请指正。
torch.mul()、torch.mm()、torch.bmm()、torch.matmul()、@
进行矩阵点乘,可进行高维运算;X、Y维度一致时,以下两种情况可以正常计算:
1)X.shape= =Y.shape
例如:
import torch
X = torch.ones((3 ,2, 3))
Y = torch.ones((3, 2, 3))
torch.mul(X, Y).shape
##output:torch.Size([3, 2, 3])
2)X维度小于Y的维度时(广播机制),X 的维度大小要与Y的对应维度相匹配,如:X.shape=(2, 3),Y.shape=(3,2,3);X.shape=(1, 3),Y.shape=(3,2,3);X.shape=( 3),Y.shape=(3,2,3)。输出维度与X,Y中维度大的一致;
X = torch.ones((2, 3))
Y = torch.ones((3, 2, 3))
torch.mul(X, Y).shape
##output:torch.Size([3, 2, 3])
X = torch.ones((1, 3))
Y = torch.ones((3, 2, 3))
torch.mul(X, Y).shape
##output:torch.Size([3, 2, 3])
X = torch.ones((3))
Y = torch.ones((3, 2, 3))
torch.mul(X, Y).shape
##output:torch.Size([3, 2, 3])
torch.mm():只能进行二维的矩阵运算,X,Y维度满足矩阵乘法运算的要求;
torch.bmm():批量运算,对应批次做矩阵相乘,只能进行三维的矩阵运算,dim0:是批量大小,必须一致;(dim1,dim2):必须满足矩阵乘法运算要求;
X = torch.ones((2,3))
Y = torch.ones((3,3))
torch.mm(X, Y).shape
##output:torch.Size([2, 3])
X = torch.range(1,12).reshape((2,2,3))
Y = torch.ones((2,3,1))
torch.bmm(X, Y)
##output:tensor([[[ 6.],
[15.]],
[[24.],
[33.]]]),size:(2,2,1)
做矩阵乘法,两者等价,功能最为强大;可进行高维矩阵运算,维度不一样时通过广播机制也可以进行计算。
**维度的匹配**:要保证最后两个维度大小满足矩阵乘法的要求(a,b)*(b,c);其他维度大小:
如X.shape=(2,3,4),Y.shape=(**6,5,2,3 ,5** ,2,4,1),Y的维度(6,5,2,3,5)是相对于X多出的维度,可以是任意值大小;Y的维度(2,4,1)是和X的维度相对应的维度,因此后两维需满足矩阵乘法运算的要求;对应于X的dim0的Y维度,需要保持一样或是1。
## 两者等价
X = torch.rand((2,2,4))
Y = torch.rand((2,4,1))
a = X @ Y
b = torch.matmul(X, Y)
a == b
##output:tensor([[[True],
[True]],
[[True],
[True]]])
##两者进行高维运算
## 维度匹配规则和torch.mul()、torch.bmm()类似
X = torch.rand((2,4))
Y = torch.rand((2,2,4,1))
a = X @ Y
b = torch.matmul(X, Y)
a.shape,a == b
##output:(torch.Size([2, 2, 2, 1]),
tensor([[[[True],
[True]],
[[True],
[True]]],
[[[True],
[True]],
[[True],
[True]]]]))
X = torch.rand((1,1,1,2,1,2,4))
Y = torch.rand((2,2,4,1))
a = X @ Y
b = torch.matmul(X, Y)
a.shape,a == b
##output:(torch.Size([1, 1, 1, 2, 2, 2, 1]),
tensor([[[[[[[True]],
[[True]]],
[[[True]],
[[True]]]]]]]))
X = torch.rand((2,3,4))
Y = torch.rand((6,5,2,3,5,2,4,1))
a = X @ Y
b = torch.matmul(X, Y)
a.shape
##output:torch.Size([6, 5, 2, 3, 5, 2, 3, 1])