torch.mm(input, mat2, out=None) → Tensor
Plain matrix multiplication of two 2-D tensors; no broadcasting. input is an (n×m) tensor, mat2 is an (m×p) tensor, and the output is an (n×p) tensor.
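A minimal torch.mm sketch (shapes chosen arbitrarily for illustration):

a = torch.randn(2, 3)
b = torch.randn(3, 4)
torch.mm(a, b).shape
# torch.Size([2, 4])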
torch.bmm(input, mat2, out=None) → Tensor
Batch matrix multiplication; no broadcasting. input is a (b×n×m) tensor, mat2 is a (b×m×p) tensor, and the output is a (b×n×p) tensor.
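A minimal torch.bmm sketch (batch size and shapes are arbitrary here); both arguments must carry the same batch dimension:

x = torch.randn(10, 3, 4)
y = torch.randn(10, 4, 5)
torch.bmm(x, y).shape
# torch.Size([10, 3, 5])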
torch.matmul(input, other, out=None) → Tensor
Matrix multiplication with broadcasting support.
import torch

# 1-D × 1-D: a dot product, returning a 0-D tensor
a = torch.Tensor([1, 2, 3])
b = torch.Tensor([1, 1, 1])
torch.matmul(a, b)
# tensor(6.)
# 2-D × 2-D: ordinary matrix multiplication
a = torch.Tensor(3, 2)   # uninitialized 3×2 tensor, values are arbitrary
b = torch.Tensor(2, 3)
torch.matmul(a, b).shape
# torch.Size([3, 3])
# 3-D × 3-D with matching batch dimension: batched matrix multiplication
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
# torch.Size([10, 3, 5])
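The examples above do not exercise the broadcasting that sets torch.matmul apart; a sketch with non-matching batch dimensions (shapes arbitrary) shows the batch dims being broadcast against each other while the last two dims are matrix-multiplied:

x = torch.randn(2, 1, 3, 4)
y = torch.randn(5, 4, 6)
torch.matmul(x, y).shape   # batch dims (2,1) and (5,) broadcast to (2,5)
# torch.Size([2, 5, 3, 6])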
torch.mul(input, other, out=None) → Tensor
Element-wise multiplication, with broadcasting.

# same-shape tensors: element-wise product, shape unchanged
a = torch.Tensor(2, 3)
b = torch.Tensor(2, 3)
torch.mul(a, b).shape
# torch.Size([2, 3])
# broadcasting: shapes (3,) and (3, 1) broadcast to (3, 3)
a = torch.Tensor([1, 2, 3])
b = torch.Tensor([[1], [2], [3]])
torch.mul(a, b)
# tensor([[1., 2., 3.],
#         [2., 4., 6.],
#         [3., 6., 9.]])
torch.mm() and torch.bmm() are plain matrix multiplication and batch matrix multiplication respectively; neither broadcasts, which keeps them simple and explicit. torch.matmul() is also matrix multiplication, but it broadcasts, so it is more flexible: it handles both plain and batched matrix multiplication.
torch.mul() is element-wise multiplication.
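To make the distinction concrete, a small sketch (shapes chosen arbitrarily): torch.mm rejects inputs that are not 2-D, while torch.matmul accepts a batched input and broadcasts a 2-D argument over the batch dimension.

x = torch.randn(10, 3, 4)
w = torch.randn(4, 5)
# torch.mm(x, w)            # RuntimeError: torch.mm only accepts 2-D tensors
torch.matmul(x, w).shape    # w is broadcast across the batch dimension of x
# torch.Size([10, 3, 5])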