pytorch函数mm() mul() matmul()区别

文章目录

    • 1、torch.mul()
    • 2、torch.mm()
    • 3、torch.matmul()
      • 3.1 输入都是二维
      • 3.2 输入都是三维
      • 3.3 输入的维度不同

1、torch.mul()

  • torch.mul(a, b)是矩阵a和b对应位相乘
  • torch.mul(a, b)中a和b的维度相等,但是,对应维度上的数字可以不同,可以用利用广播机制扩展到相同的形状,再进行点乘操作
# 比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵
>>> a = torch.rand(1, 2)
>>> b = torch.rand(1, 2)
>>> torch.mul(a, b)  # 返回 1*2 的tensor

# 乘列向量
>>> a = torch.ones(3,4) 
>>> a
tensor([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
       [2.],
       [3.]])
>>> torch.mul(a, b)
tensor([[1., 1., 1., 1.],
       [2., 2., 2., 2.],
       [3., 3., 3., 3.]])

2、torch.mm()

  • torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(3, 4),b的维度是(4, 2),返回的就是(3, 2)的矩阵torch.mm(a, b)针对二维矩阵
>>> a = torch.ones(3,4)
>>> b = torch.ones(4,2)
>>> torch.mm(a, b)
tensor([[4., 4.],
        [4., 4.],
        [4., 4.]])

mm()是mutmul()的简称?

3、torch.matmul()

  • torch.matmul(a, b)也是一种类似于矩阵相乘操作的tensor联乘操作,一般是高维矩阵a和b相乘,但是它可以利用python 中的广播机制,处理一些维度不同的tensor结构进行相乘操作。

3.1 输入都是二维

  • 当输入都是二维时,就是普通的矩阵乘法,和tensor.mm()函数用法相同。
    pytorch函数mm() mul() matmul()区别_第1张图片

3.2 输入都是三维

  • 下面看一个两个都是3维的例子:
    pytorch函数mm() mul() matmul()区别_第2张图片
    将b的第0维1broadcast成2提出来,后两维做矩阵乘法即可。

3.3 输入的维度不同

  • 当输入有多维时,把多出的一维作为batch提出来,其他部分做矩阵乘法。
    pytorch函数mm() mul() matmul()区别_第3张图片
  • 再看一个复杂一点的,是官网的例子:
    pytorch函数mm() mul() matmul()区别_第4张图片
    首先把a的第0维2作为batch提出来,则a和b都可看作三维。再把a的1broadcat成5,提取公因式5。(这样说虽然不严谨,但是便于理解。)然后a剩下(3,4),b剩下(4,2),做矩阵乘法得到(3,2)。

参考:https://www.jianshu.com/p/e277f7fc67b3
参考:https://blog.csdn.net/qsmx666/article/details/105783610

你可能感兴趣的:(pytorch)