pytorch中的乘法:mul matmul mm bmm @ *

目录

mul

broadcast

运算符*

matmul

运算符@

mm

bmm


mul

torch.mul(a, b)

  • 如果a和b的shape相同,则结果是对应位置的元素相乘,输出的shape不变。注意:mul不是矩阵乘法。
  • 如果a和b的shape不同,则两个shape必须是broadcastable的(见下文)。首先对a和b进行broadcast,之后a和b的shape就相同了,然后对应元素相乘,输出的shape是broadcast之后的。
  • 此外,a和b可以都是标量,此时就是普通标量乘法;a和b也可以一个是标量另一个是多维的。假如a是标量,b是多维的,此时就是用a乘以b的每一个元素,输出shape和b相同,即相当于对标量a做了broadcast。

如果a不是标量时,torch.mul(a, b)相当于a.mul(b)。

broadcast

两个shape必须满足以下两个要求才是broadcastable的:

  1. a和b至少都是1维的,比如:向量是1维的,矩阵是2维的
  2. 从后往前遍历a和b的维度,对应位置的维度必须相同,或者有一个是1,或者有一个不存在。

举例如下:

例1:

a的shape是(5, 3, 4, 1)

b的shape是(    3, 1, 1)

从后往前看,最后边的1和1是相同的;然后4和1,有一个是1;然后3和3相同;然后5对应的位置,b的不维度不存在;因此a和b是broadcastable的。

例2:

a的shape是(5, 2, 4, 1)

b的shape是(    3, 1, 1)

从后往前看,最后边的1和1是相同的;然后4和1,有一个是1;然后3和2不相同;因此a和b不是broadcastable的。

broadcast之后,所有维度都取较大的值,举例如下:

例1:

a的shape是(5, 1, 4, 1)

b的shape是(    3, 1, 1)

broadcast之后的shape是(5, 3, 4, 1)

运算符*

通过实践发现a*b和torch.mul(a, b)是等价的。从python官网(https://docs.python.org/3/library/operator.html)查到如下信息:

matmul

torch.matmul(a, b)或者a.matmul(b)

  • 如果a和b都是1维的,即都是向量,则进行向量点乘,相当于shape分别是(1, n)和(n, 1)的两个矩阵做矩阵乘法,输出是一个标量。
  • 如果a和b都是2维的,即都是矩阵,则进行矩阵乘法,要求a的列数等于b的行数,即a的shape是(N, M),b的shape是(M, K)。
  • 如果a是1维的,b是2维的,则要求a的长度等于b的行数;假设a的shape是4,b的shape是(4, 6),首先将a的shape改成(1, 4);然后进行矩阵乘法,结果的shape是(1, 6);最后将结果改为1维,即输出长为6的向量。
  • 如果a是2维的,b是1维的,则要求a的列数等于b的长度;假设a的shape是(4, 5),b的shape是5,首先将b的shape改成(5, 1);然后进行矩阵乘法,结果的shaoe是(4, 1);最后将结果改为1维,即输出长为4的向量。

注意:这里红色部分的情况和蓝色部分没有本质区别。

  • 如果a是大于2维的,b是2维的,则进行批量矩阵乘法;假设a的shape是(2, 3, 4),b的shape是(4, 2),则我们把a看做2个shape为(3, 4)的子矩阵,即batch size等于2,我们记为a0和a1。此时的运算相当于是做a0和b的矩阵乘法,以及a1和b的矩阵乘法,显然b被重复使用了2次,我们称之为broadcast(见上文)。显然:a的最后2维和b的shape必须满足矩阵乘法。对于a是4维的情况,也是一样的,假设a的shape是(3, 4, 5, 6),则b的shape需要是(6, N),此时把a看做是3*4=12个shape为(5, 6)的矩阵,如下图。此时b被重复使用了12次
a00 a01 a02 a03
a10 a11 a12 a13
a20 a21 a22 a23
  • 如果a是2维的,b是大于2维的,和上述情况是一样的。
  • 如果a是大于2维的,b是1维的,或者反过来a是1维的,b是大于2维的,则和上述也是一样的,只不过需要把1维的向量补充成2维的,即把长为N的向量变为(1, N)或者(N, 1)的矩阵,用以满足做矩阵乘法。
  • 如果a和b都是大于2维的,其实也是一样的,假设a的shape是(2, 1, 6, 4, 5),b的shape是(4, 6, 5, 9),则必须保证a和b的最后2维满足矩阵乘法;此时我们把a看做是2*1*6个shape为(4, 5)的子矩阵,同理b是4*6个shape为(5, 9)的子矩阵;此时我们把子矩阵看做独立的单元,即我们不去关心矩阵乘法的细节,两个矩阵相乘就看做两个标量相乘,那么现在就是shape为(2, 1, 6)和shape为(4, 6)的矩阵做上文的mul操作(即对应位置元素相乘,而不是矩阵乘法)。因此我们需要做boradcast,都变成(2, 4, 6),则结果的shape也是(2, 4, 6),此时结果的每个元素不再是一个个标量,而是一个个子矩阵,每个子矩阵的shape都是(4, 9),因为a的每个元素都是shape为(4, 5)的子矩阵,b的每个元素都是shape为(5, 9)的子矩阵

注意:此时不仅要求a和b的最后2维满足矩阵乘法,还要求前边的维度是broadcastable的,见上文。

运算符@

通过实践发现a@b和torch.matmul(a, b)是等价的。从python官网(https://docs.python.org/3/library/operator.html)查到如下信息:

mm

其实是matmul的简化版本,要求输入的a和b必须都是2维的,且必须满足矩阵乘法。

bmm

b是batch的意思,即批量矩阵乘法,其实是matmul的简化版本,要求输入的a和b必须都是3维的,且第0维必须相等,最后2维必须满足矩阵乘法。

 

你可能感兴趣的:(pytorch,乘法,pytorch)