pytorch中torch.mul() 和 torch.mm() 的区别

torch.mul(a, b) 是矩阵a和b对应位相乘,a和b的维度必须相等。
torch.mm(a, b) 是矩阵a和b矩阵相乘

import torch

a = torch.rand(1, 2)
b = torch.rand(1, 2)
c = torch.rand(2, 3)

print(torch.mul(a, b))  # 返回 1*2 的tensor
print(torch.mm(a, c))   # 返回 1*3 的tensor
print(torch.mul(a, c))  # 由于a、b维度不同,报错

你可能感兴趣的:(深度学习,mul和mm区别)