torch.matmul() 将两个张量相乘划分成了五种情形:
一维 × 一维、二维 × 二维、一维 × 二维、二维 × 一维、涉及到三维及三维以上维度的张量的乘法。
1.如果两个张量都是一维的,即 torch.Size([n]) ,此时返回两个向量的点积。作用与 torch.dot() 相同,同样要求两个一维张量的元素个数相同。
2.如果两个参数都是二维张量,那么将返回矩阵乘积。作用与 torch.mm() 相同,同样要求两个张量的形状需要满足矩阵乘法的条件,即(n×m)×(m×p)=(n×p)
3.如果第一个参数是一维张量,第二个参数是二维张量,那么在一维张量的前面增加一个维度,然后进行矩阵乘法,矩阵乘法结束后移除添加的维度。
第一个参数是大小为n的1-D张量,第二个参数为(n×m)张量,那么在一维张量的前面增加一个维度,也就是第一个参数变成(1 x n)的张量,正常的torch.mm。
4.如果第一个参数是二维张量(矩阵),第二个参数是一维张量(向量),那么将返回矩阵×向量的积。作用与 torch.mv() 相同。另外要求矩阵的形状和向量的形状满足矩阵乘法的要求。
第一个参数是(n×m)张量,第二个参数是大小为m的1-D张量,也就是第二个参数是(mx1)张量,输出将是大小为n的1-D,也就是(nx1)的张量,正常的torch.mm。
5.如果两个参数均至少为一维,且其中一个参数的 ndim > 2。这条规则将所有涉及到三维张量及三维以上的张量的乘法分为三类:一维张量 × 高维张量、高维张量 × 一维张量、二维及二维以上的张量 × 二维及二维以上的张量。
如果第一个参数是一维张量,那么在此张量之前增加一个维度。
如果第二个参数是一维张量,那么在此张量之后增加一个维度。
由于上述两个规则,所有涉及到一维张量和高维张量的乘法都被转变为二维及二维以上的张量 × 二维及二维以上的张量。
然后除掉最右边的两个维度,对剩下的维度进行广播。然后就可以进行批量矩阵乘法。