PyTorch中的矩阵乘法

torch.mm()

torch.mm(input, mat2, out=None) → Tensor

矩阵乘法,不进行 broadcast

torch.bmm()

输入1 :(b×n×m) tensor, 输入2:(b×m×p) tensor, 输出:(b×n×p) tensor.
batch 式的矩阵乘法,不broadcast

torch.matmul()

torch.matmul(input, other, out=None) → Tensor

矩阵乘法,有broadcast功能

  1. 如果输入的tensor都是一维,则计算点积:
a = torch.Tensor([1,2,3])
b = torch.Tensor([1,1,1])
torch.matmul(a,b)
# tensor(6.)
  1. 如果输入tensor都是二维矩阵,则计算矩阵乘法:
a = torch.Tensor(3,2)
b = torch.Tensor(2,3)
torch.matmul(a,b).shape
# torch.Size([3, 3])
  1. 如果是[b x m x k]与[b x k x n]形式的矩阵乘法,则进行batched matrix multiply得到[b x m x n]矩阵:
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
# torch.Size([10, 3, 5])

torch.mul()

  1. 如果输入tensor形状相同,则元素相乘:
a = torch.Tensor(2,3)
b = torch.Tensor(2,3)
torch.mul(a,b).shape
# torch.Size([2, 3])
  1. 其他:
a = torch.Tensor([1,2,3])
b = torch.Tensor([[1],[2],[3]])
torch.mul(a,b)
# tensor([[1., 2., 3.],
#         [2., 4., 6.],
#         [3., 6., 9.]])

总结

torch.mm()和torch.bmm()分别是单纯矩阵乘法和batch矩阵乘法,不进行broadcast,比较简单明了。torch.matmul()也是矩阵乘法,但是有broadcast,比较灵活,可以单纯矩阵乘法也可以batch矩阵乘法。
torch.mul()则是元素相乘。

你可能感兴趣的:(线性代数,pytorch)