Usage and differences of torch.bmm() and torch.matmul() for matrix multiplication

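torch.bmm() and torch.matmul() are both matrix multiplication functions. The difference is that torch.matmul is more general: it accepts tensors of other dimensionalities and broadcasts batch dimensions, while torch.bmm accepts only two 3-D tensors.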

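Both can multiply 3-D tensors. The first dimension is only a batch index; the last two dimensions are the actual matrices, and matrices at the same batch index are multiplied together.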

import torch

x = torch.rand(2, 3, 2)  # 2 matrices of size 3x2
y = torch.rand(2, 2, 4)  # 2 matrices of size 2x4
# The product is 2 matrices of size 3x4
print('Matrix x:', x)
print('Matrix y:', y)
print('matmul result:')
print(torch.matmul(x, y))
print('bmm result:')
print(torch.bmm(x, y))

Output

Matrix x: tensor([[[0.9071, 0.4671],
         [0.4678, 0.9219],
         [0.2115, 0.1356]],

        [[0.9514, 0.4708],
         [0.6294, 0.0667],
         [0.2318, 0.7602]]])
Matrix y: tensor([[[0.3470, 0.9198, 0.8931, 0.3055],
         [0.2137, 0.0132, 0.4909, 0.8819]],

        [[0.4307, 0.1535, 0.2024, 0.9597],
         [0.5891, 0.7382, 0.1799, 0.8932]]])
matmul result:
tensor([[[0.4146, 0.8405, 1.0394, 0.6891],
         [0.3594, 0.4424, 0.8703, 0.9559],
         [0.1024, 0.1963, 0.2554, 0.1842]],

        [[0.6871, 0.4936, 0.2772, 1.3335],
         [0.3104, 0.1459, 0.1394, 0.6637],
         [0.5477, 0.5968, 0.1837, 0.9016]]])
bmm result:
tensor([[[0.4146, 0.8405, 1.0394, 0.6891],
         [0.3594, 0.4424, 0.8703, 0.9559],
         [0.1024, 0.1963, 0.2554, 0.1842]],

        [[0.6871, 0.4936, 0.2772, 1.3335],
         [0.3104, 0.1459, 0.1394, 0.6637],
         [0.5477, 0.5968, 0.1837, 0.9016]]])
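To check the claim that the first dimension is only a batch index, here is a minimal sketch that multiplies each pair of matrices by hand and compares the stacked result with torch.bmm and torch.matmul. The variable names manual and batched, and the fixed seed, are illustrative choices and not part of the original example.

```python
import torch

torch.manual_seed(0)  # fixed seed (an arbitrary choice) so the check is repeatable

x = torch.rand(2, 3, 2)  # batch of 2 matrices, each 3x2
y = torch.rand(2, 2, 4)  # batch of 2 matrices, each 2x4

# Multiply the matrices at each batch index by hand,
# then stack the 3x4 results along a new first dimension.
manual = torch.stack([x[i] @ y[i] for i in range(x.shape[0])])

batched = torch.bmm(x, y)

print(batched.shape)                                 # torch.Size([2, 3, 4])
print(torch.allclose(manual, batched))               # True
print(torch.allclose(torch.matmul(x, y), batched))   # True
```

For two 3-D inputs with matching batch sizes, the two functions therefore give the same result.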

Here is the difference: matmul also works on standalone 2-D matrices, whereas bmm does not, as shown below:

import torch

x = torch.rand(3, 2)  # a single 3x2 matrix
y = torch.rand(2, 4)  # a single 2x4 matrix

print('Matrix x:', x)
print('Matrix y:', y)
print('matmul result:')
print(torch.matmul(x, y))
print('bmm result:')
print(torch.bmm(x, y))  # raises RuntimeError: bmm requires 3-D tensors

Output

Matrix x: tensor([[0.6754, 0.4208],
        [0.7033, 0.2135],
        [0.2919, 0.3624]])
Matrix y: tensor([[0.5826, 0.6710, 0.3636, 0.3826],
        [0.8267, 0.8261, 0.6235, 0.8551]])
matmul result:
tensor([[0.7414, 0.8009, 0.5080, 0.6183],
        [0.5863, 0.6483, 0.3888, 0.4517],
        [0.4696, 0.4952, 0.3321, 0.4215]])
bmm result:
Traceback (most recent call last):
  File "E:/储物柜/study/研究生/深度学习/works/SiameseNetwork.py", line 11, in 
    print(torch.bmm(x,y))
RuntimeError: Expected 3-dimensional tensor, but got 2-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)

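The final call fails because bmm only supports 3-D tensors: both arguments must be 3-D, with shapes (b, n, m) and (b, m, p).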
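If bmm has to be used with 2-D matrices, one possible workaround is to add a batch dimension of size 1 and remove it afterwards; torch.mm is the strictly 2-D alternative. The sketch below also shows one case where matmul is more general: it broadcasts a single 2-D matrix across a batch. The names out_mm, out_bmm, xb and out_broadcast are illustrative only.

```python
import torch

x = torch.rand(3, 2)  # a single 3x2 matrix
y = torch.rand(2, 4)  # a single 2x4 matrix

# Option 1: torch.mm multiplies exactly two 2-D matrices.
out_mm = torch.mm(x, y)

# Option 2: add a batch dimension of size 1, call bmm, then drop it again.
out_bmm = torch.bmm(x.unsqueeze(0), y.unsqueeze(0)).squeeze(0)

print(out_mm.shape)                     # torch.Size([3, 4])
print(torch.allclose(out_mm, out_bmm))  # True

# matmul broadcasts: a batch of 3x2 matrices times one shared 2x4 matrix.
xb = torch.rand(2, 3, 2)
out_broadcast = torch.matmul(xb, y)
print(out_broadcast.shape)              # torch.Size([2, 3, 4])
```

Passing xb and y directly to torch.bmm would fail for the same reason as above, since y is 2-D.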
