python 和 pytorch中的矩阵乘法

python 矩阵有两种形式:array 和 matrix 对象,
本篇文章主要介绍 np.array() 和 np.mat() 这两种类型矩阵做乘法时的异同
这两种数据类型均有三种操作方式:

  • 乘号 *
  • np.dot() 或 @
  • np.multiply()

np.array() 类型

a = np.array([[1., 2.], [3., 4.]])
b = np.array([[1., 2.], [3., 4.]])

# 对应位置元素做乘法
c_1 = a * b
[[ 1.  4.]
 [ 9. 16.]]
 
# 线性代数中常规的矩阵乘法
c_2 = np.dot(a, b)	# 等价于 a @ b
[[ 7. 10.]
 [15. 22.]]
 
# 对应位置元素做乘法
c_3 = np.multiply(a, b)
[[ 1.  4.]
 [ 9. 16.]]

np.mat() 类型

a = np.mat([[1., 2.], [3., 4.]])
b = np.mat([[1., 2.], [3., 4.]])

# 线性代数中常规的矩阵乘法
c_1 = a * b
[[ 7. 10.]
 [15. 22.]]
 
# 线性代数中常规的矩阵乘法
c_2 = np.dot(a, b)
[[ 7. 10.]
 [15. 22.]]
 
# 对应位置元素做乘法
c_3 = np.multiply(a, b)
[[ 1.  4.]
 [ 9. 16.]]

通过上面代码可以知道:
( 1 ) np.dot() 或 @ 对于两种数据格式均为矩阵乘法
( 2 ) np.multiple() 对于两种数据格式均为按对应位置元素相乘
( 3 ) * 对于 array 是按元素相乘,对于mat是做矩阵乘法

pytorch中的矩阵乘法

( 1 ) torch.mul(a, b)
可以广播, 是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵;
( 2 ) torch.mm(a, b)
是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵。只能处理二维矩阵
( 2 ) torch.matmul()
能处理batch、广播的矩阵

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[1., 2.], [3., 4.]])

c_1 = torch.mm(a, b)
tensor([[ 7., 10.],
        [15., 22.]])
        
c_2 = torch.mul(a, b)
tensor([[ 1.,  4.],
        [ 9., 16.]])

torch.matmul()
等价于使用 @
当输入是二维时,和 torch.mm 函数用法相同
当输入是多维时,把多出的一维作为batch提出来,其他部分做矩阵乘法

a = torch.ones((5, 3, 4))
b = torch.ones((4, 2))

c = torch.matmul(a, b)
print(c.shape)

torch.Size([5, 3, 2])

下面看一个两个都是三维的例子

a = torch.ones((2, 3, 4))
b = torch.ones((1, 4, 2))

c = torch.matmul(a, b)
print(c.shape)

torch.Size([2, 3, 2])

再看一个复杂点的例子

a = torch.ones((2, 1, 3, 4))
b = torch.ones((1, 4, 2))

c = torch.matmul(a, b)	# 等价于 c = a @ b

print(c.shape)

torch.Size([2, 1, 3, 2])

首先把a的第0维2作为batch提出来,则a和b都可看作三维。再把a的1broadcat成5,提取公因式5。(这样说虽然不严谨,但是便于理解。)然后a剩下(3,4),b剩下(4,2),做矩阵乘法得到(3,2)。

参考:
torch.matmul()用法介绍
np.mat()、np.matrix()、np.array()函数解析(最清晰的解释)

你可能感兴趣的:(pytorch,python,python,矩阵,pytorch)