torch.mm, torch.matmul 和torch.mul的区别: RuntimeError: 2D tensors expected, got 4D

点乘:torch.mul

torch.mul(a, b)是矩阵a和b点乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵

 

矩阵乘法:torch.mm和torch.mul

torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵

torch.matmul则可以用于两个图片矩阵相乘,如a的维度是[bs,3,128,256]; b的维度是[bs, 3, 256,512], 则torch.mm(a, b)就会报错;但torch.matmul(a,b)就不会报错!

如下所示:

(Pdb) TL = torch.matmul(T,L.view(T.shape[0],T.shape[1],-1,1))
(Pdb) TL = torch.mm(T,L.view(T.shape[0],T.shape[1],-1,1))
*** RuntimeError: 2D tensors expected, got 4D, 4D tensors at /tmp/pip-req-build-ocx5vxk7/aten/src/THC/generic/THCTensorMathBlas.cu:282

 

你可能感兴趣的:(Code,Deep,Learning,pytorch,python,乘法,矩阵乘法)