torch中的乘法符号(*),torch.mm()和torch.matmul(),torch.mul(), torch.bmm()

前言

torch中常见的一些矩阵乘法和元素乘积,说白了无非就是以下四种,为了避免忘了,做个笔记

  1. 乘法符号 *
  2. torch.mul()
  3. torch.mm
  4. torch.matmul
  5. torch.dot

1. 对比

  1. 乘法符号*
# shape=(2,5)
node = tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

node_0 = node.unsqueeze(-1)
'''
tensor([[[1],
         [1],
         [1],
         [0],
         [0]],
        [[1],
         [1],
         [1],
         [1],
         [1]]], torch.Size([2,1,5]))
'''
node_1 = node.unsqueeze(1)
'''
tensor([[[1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1]]], torch.Size([2,5,1]))
'''

'True, 值相同'
node_mask.unsqueeze(-1) * node_mask.unsqueeze(1) == /
node_mask.unsqueeze(-1) * node_mask.unsqueeze(1)

print(node_mask.unsqueeze(-1) * node_mask.unsqueeze(1))
'shape=[2,5,5]'

所以,乘法符号是对应的tensor和元素乘。

2.torch.mul()

和上面一样,不同是有官方解释

torch.mul(input, value, out=None)

用标量值value乘以输入input的每个元素,并返回一个新的结果张量。 out=tensor∗value

# shape=(2,5)
node = tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
node_0 = node.unsqueeze(-1) # [2,5,1]
node_1 = node.unsqueeze(1) # [2,1,5]
'shape=[2,5,5]'
torch.mul(node_0, node_1) #
'True,看来是相同的'
torch.mul(node_0, node_1) ==  node_mask.unsqueeze(-1) * node_mask.unsqueeze(1)
  1. torch.matmul()和torch.mm()和torch.bmm()
  • 在矩阵的情况下,矩阵就是shape=n×m格式,只能用torch.mm()torch.matmul()不可以使用torch.bmm()
  • 但是大多数情况都是带有batch_size,也就是sahpe=[batch_size, n, m],只能用torch.matmul()torch.bmm()
# shape=(2,5)
node = tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
node_0 = node.unsqueeze(-1) # [2,5,1]
node_1 = node.unsqueeze(1) # [2,1,5]

'先看mm和matmul'
torch.mm(node_mask, node_mask.transpose(0,1)) # [2,2]
torch.matmul(node_mask, node_mask.transpose(0,1)) # [2,2]
# 下面 True
torch.mm(node_mask, node_mask.transpose(0,1)) == /
torch.matmul(node_mask, node_mask.transpose(0,1))

# error:  torch.mm()就不可以在这种三维tensor下用
torch.mm(node_0, node_1) # 报错
'而torch.matmul()'
torch.matmul(node_0, node_1)


'看高维Tensor'
torch.matmul(node_0, node_1) # --> [2,1,1]
torch.bmm(node_0, node_1) # --> [2,1,1]
torch.bmm(node_mask.unsqueeze(1), node_mask.unsqueeze(-1)) # --> [2,1,1]

'True'
torch.matmul(node_0, node_1) == torch.bmm(node_0, node_1) 

注意:dot很快运算速度远超于matmul

你可能感兴趣的:(pytorch,python)