下面简单回顾一下矩阵中的乘法:(严谨的说,其实应该说是矩阵乘法和矩阵内积)
1、矩阵乘法
矩阵乘法也就是我们常说的矩阵向量积(也称矩阵外积、矩阵叉乘)
它要求前一个矩阵的行数等于后一个矩阵的列数,其计算方法是计算结果的每一行元素为前一个矩阵的每一行元素与后一个矩阵的每一列对应元素相乘,之后求和。下面 2 ∗ 3 2*3 2∗3矩阵与 3 ∗ 5 3*5 3∗5矩阵为例:
[ 1 1 1 1 1 1 ] × [ 1 2 3 4 5 1 2 3 4 5 1 2 3 4 5 ] = [ 3 6 9 12 15 3 6 9 12 15 ] \begin{gathered} \begin{bmatrix} 1 & 1 & 1\\ 1 & 1 & 1 \end{bmatrix} \times \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 1 & 2 & 3 & 4 & 5 \\ 1 & 2 & 3 & 4 & 5 \end{bmatrix} \end{gathered}=\begin{bmatrix} 3 & 6 & 9 & 12 & 15 \\ 3 & 6 & 9 & 12 & 15 \end{bmatrix} [111111]×⎣⎡111222333444555⎦⎤=[33669912121515]
其计算方法为:
1 ∗ 1 + 1 ∗ 1 + 1 ∗ 1 = a 11 = 3 , 1 ∗ 2 + 1 ∗ 2 + 1 ∗ 2 = a 12 = 6 … … 1*1+1*1+1*1=a11=3, \, \,\,1*2+1*2+1*2=a12=6…… 1∗1+1∗1+1∗1=a11=3,1∗2+1∗2+1∗2=a12=6……
其中a11为第一行第一个元素,以此类推
2、矩阵内积
矩阵点法也就是我们常说的矩阵点乘
即矩阵的对应元素相乘,故它要求两个矩阵形状一样,下面 2 ∗ 3 2*3 2∗3矩阵与 2 ∗ 3 2*3 2∗3矩阵为例:
[ 1 1 1 1 1 1 ] . [ 1 2 3 4 5 6 ] = [ 1 2 3 4 5 6 ] \begin{gathered} \begin{bmatrix} 1 & 1 & 1\\ 1 & 1 & 1 \end{bmatrix} . \begin{bmatrix}1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \end{gathered}=\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} [111111].[142536]=[142536]
在进入正题之前,先扯点儿闲篇——大家应该都知道numpy(至少听说过,python的一个数值计算库,pytorch不火的时候,numpy还是很好用的),而pytorch,主要特点是可以使用GPU加速运算,但是计算上和numpy有很多类似之处,那好,介绍pytorch的矩阵乘法之前,先说说numpy中ndarray中矩阵的乘法:
numpyt中点乘使用*或者np.multiply(),而叉乘使用@, np.dot(), np.matmul()
测试测序如下:
import numpy as np
print("numpy")
A = np.array([[1, 2, 3, 6], [2, 3, 4, 3], [2, 3, 4, 4]])
B = np.array([[1, 0, 1, 4], [2, 1, -1, 0], [2, 1, 5, 0]])
C = np.array([[1, 0, 3], [0, 1, 2], [-1, 0, 1], [-1, 0, 1]])
# 对应位置相乘,点乘
print("矩阵对应元素相乘 点乘")
print("*运算符\n", A*B)
print("np.multiply\n", np.multiply(A, B))
print("矩阵相乘 叉乘")
print("A.dot\n", A.dot(C)) # 矩阵乘法
print("@运算符\n", A@C)
print("np.matmul\n", np.matmul(A, C), '\n')
请移步numpy中dot和matmul的区别
而pytorch中用法略有不同,其中点乘使用*或者np.mul(),而叉乘使用@, torch.mm(), torch.matmul()(注意这里没有dot函数,使用torch.mm函数)
import torch
print("pytorch")
a = torch.ones(2, 3)
c = torch.FloatTensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
#b = torch.randint(1, 9, (2, 3))
b = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
print("矩阵对应元素相乘 点乘")
print("*运算符\n", a * b)
print("torch.mul\n", torch.mul(a, b))
print("矩阵相乘 叉乘")
print("@运算符\n", a@c)
print("torch.mm\n", torch.mm(a, c))
print("torch.matmul\n", torch.matmul(a, c), "\n")
输出结果:
下面又有一个问题,torch.mm()和torch.matmul()到底有什么区别?
可以参考官网教程
如果你懒得看,你可以看下面这两张我从官网上截的图
当然了,如果还是难以理解的话,请移步这里。
参考:
[1]https://blog.csdn.net/She_Said/article/details/98034841
[2]https://www.jb51.net/article/177406.htm
[3]https://pytorch.org/docs/stable/torch.html