torch
中常见的一些矩阵乘法和元素乘积,说白了无非就是以下四种,为了避免忘了,做个笔记
*
torch.mul()
torch.mm
torch.matmul
torch.dot
*
# shape=(2,5)
node = tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
node_0 = node.unsqueeze(-1)
'''
tensor([[[1],
[1],
[1],
[0],
[0]],
[[1],
[1],
[1],
[1],
[1]]], torch.Size([2,1,5]))
'''
node_1 = node.unsqueeze(1)
'''
tensor([[[1, 1, 1, 0, 0]],
[[1, 1, 1, 1, 1]]], torch.Size([2,5,1]))
'''
'True, 值相同'
node_mask.unsqueeze(-1) * node_mask.unsqueeze(1) == /
node_mask.unsqueeze(-1) * node_mask.unsqueeze(1)
print(node_mask.unsqueeze(-1) * node_mask.unsqueeze(1))
'shape=[2,5,5]'
所以,乘法符号是对应的tensor
和元素乘。
2.torch.mul()
和上面一样,不同是有官方解释
torch.mul(input, value, out=None)
用标量值value
乘以输入input
的每个元素,并返回一个新的结果张量。 out=tensor∗value
# shape=(2,5)
node = tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
node_0 = node.unsqueeze(-1) # [2,5,1]
node_1 = node.unsqueeze(1) # [2,1,5]
'shape=[2,5,5]'
torch.mul(node_0, node_1) #
'True,看来是相同的'
torch.mul(node_0, node_1) == node_mask.unsqueeze(-1) * node_mask.unsqueeze(1)
torch.matmul()和torch.mm()和torch.bmm()
shape=n×m
格式,只能用torch.mm()
和torch.matmul()
,不可以使用torch.bmm()
batch_size
,也就是sahpe=[batch_size, n, m]
,只能用torch.matmul()
和torch.bmm()
# shape=(2,5)
node = tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
node_0 = node.unsqueeze(-1) # [2,5,1]
node_1 = node.unsqueeze(1) # [2,1,5]
'先看mm和matmul'
torch.mm(node_mask, node_mask.transpose(0,1)) # [2,2]
torch.matmul(node_mask, node_mask.transpose(0,1)) # [2,2]
# 下面 True
torch.mm(node_mask, node_mask.transpose(0,1)) == /
torch.matmul(node_mask, node_mask.transpose(0,1))
# error: torch.mm()就不可以在这种三维tensor下用
torch.mm(node_0, node_1) # 报错
'而torch.matmul()'
torch.matmul(node_0, node_1)
'看高维Tensor'
torch.matmul(node_0, node_1) # --> [2,1,1]
torch.bmm(node_0, node_1) # --> [2,1,1]
torch.bmm(node_mask.unsqueeze(1), node_mask.unsqueeze(-1)) # --> [2,1,1]
'True'
torch.matmul(node_0, node_1) == torch.bmm(node_0, node_1)
注意:dot很快运算速度远超于matmul