【torch tips 03】Matrix Multiplication Summary

#【torch tips 03】

2023.01.24

Matrix multiplication

  • Element-wise multiplication: torch.mul(a,b)
  • Dot product: torch.dot(a,b)
  • 2-D matrix multiplication: torch.mm(a,b)
  • Batched (3-D) matrix multiplication: torch.bmm(a,b)
  • Higher-dimensional matrix multiplication: torch.matmul(a,b)

1. Element-wise multiplication: torch.mul(a,b)

  • torch.mul(a, b) (equivalent to a * b): multiplies corresponding elements
import torch

a = torch.randn(2,3)
b = torch.randn(2,1)
res1 = a * b
res2 = torch.mul(a, b)
print(res1,"\n",res2)
tensor([[-0.5612, -0.2754,  0.6309],
        [-0.0140,  0.3515,  0.7356]]) 
 tensor([[-0.5612, -0.2754,  0.6309],
        [-0.0140,  0.3515,  0.7356]])
  • Shapes such as (2×3×4) and (2×1×4) can also be multiplied element-wise; the result has shape (2×3×4).
    This relies on broadcasting (a further sketch of the broadcasting rule follows the example below).
a = torch.randn(2,3,4)
b = torch.randn(2,1,4)
res = a * b
res.shape
torch.Size([2, 3, 4])
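As a further sketch of the broadcasting rule (the shapes below are chosen for illustration and are not from the original example): a dimension of size 1, a missing leading dimension, or a scalar is stretched to match the other operand, while dimensions that are neither equal nor 1 do not broadcast.
import torch

a = torch.randn(2, 3, 4)

# a trailing shape of (4,) broadcasts against (2, 3, 4)
b = torch.randn(4)
print((a * b).shape)   # torch.Size([2, 3, 4])

# a scalar broadcasts against any shape
print((a * 2).shape)   # torch.Size([2, 3, 4])

# (2, 3, 4) and (2, 2, 4) differ in a dimension that is neither equal nor 1,
# so a * torch.randn(2, 2, 4) raises a RuntimeError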

2. Matrix multiplication: torch.mm(a,b)

  • Standard matrix multiplication: if $a \in \mathbb{R}^{m \times n}$ and $b \in \mathbb{R}^{n \times d}$, then $output \in \mathbb{R}^{m \times d}$

  • Both inputs must be 2-D tensors

tensorA_2x3 = torch.tensor(
    [[1,2,3],
     [3,2,1]]
)

tensorF_3x2 = torch.tensor(
    [[1,2],
     [3,4],
     [5,6]]
)

print(torch.mm(tensorA_2x3, tensorF_3x2))
tensor([[22, 28],
        [14, 20]])
# b is 1-D: raises an error
b = torch.randn(3)
print(torch.mm(tensorA_2x3, b))
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

Input In [9], in ()
      1 b = torch.randn(3)
----> 2 print(torch.mm(tensorA_2x3, b))


RuntimeError: mat2 must be a matrix

3. Batched matrix multiplication: torch.bmm(a,b)

In deep learning, data is usually read in one batch at a time, so PyTorch provides a batched matrix multiplication for this: torch.bmm.

  • If A has shape b × m × n and B has shape b × n × p, then torch.bmm(A, B) has shape b × m × p: the first (batch) dimension is left untouched, and an ordinary matrix multiplication is applied to each slice (see the per-batch check after the example below)
  • Here A and B must both be 3-D tensors
a = torch.rand(3,5,2)
b = torch.rand(3,2,4)
torch.bmm(a,b).shape
torch.Size([3, 5, 4])
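A quick way to see that torch.bmm is just torch.mm applied to each batch element separately (a minimal check, not from the original post):
import torch

a = torch.rand(3, 5, 2)
b = torch.rand(3, 2, 4)

batched = torch.bmm(a, b)
# multiply each (5, 2) slice with its matching (2, 4) slice, then stack along the batch dimension
looped = torch.stack([torch.mm(a[i], b[i]) for i in range(a.shape[0])])
print(torch.allclose(batched, looped))   # True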

4. Dot product: torch.dot(a,b)

  • Both inputs must be 1-D tensors of equal length

  • The two vectors are multiplied element-wise and summed, giving a scalar

  • torch.dot(A,B) == torch.sum(A*B) (verified in the sketch after the example below)

# two 1-D tensors
a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 3, 4])
torch.dot(a,b)
tensor(20)
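The identity torch.dot(A, B) == torch.sum(A * B) listed above can be checked directly (a minimal sketch using the same tensors):
a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 3, 4])
print(torch.dot(a, b) == torch.sum(a * b))   # tensor(True)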
  • If mat1 is a 2-D tensor (matrix) and mat2 is a 1-D tensor (vector), the corresponding operation is torch.mv()
    • It only supports matrix × vector, and the matrix's second dimension must equal the vector's length (a comparison with torch.matmul follows the first example below)
import torch
import numpy as np

# [4,3] and [3]: works
tensorA_2x3 = torch.arange(12).reshape(4,3)
tensorF_3 = torch.tensor([2, 3, 4])
print("torch.mv:",torch.mv(tensorA_2x3, tensorF_3))
torch.mv: tensor([11, 38, 65, 92])
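For this 2-D × 1-D case, torch.mv agrees with torch.matmul and the @ operator (a minimal check reusing the tensors above):
A = torch.arange(12).reshape(4, 3)
v = torch.tensor([2, 3, 4])
print(torch.equal(torch.mv(A, v), torch.matmul(A, v)))   # True
print(torch.equal(torch.mv(A, v), A @ v))                # True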
# 3-D does not work: [4,1,3] and [3] fails

tensorA_2x3 = torch.arange(12).reshape(4,1,3)
tensorF_3 = torch.tensor([2, 3, 4])
print("torch.mv:",torch.mv(tensorA_2x3, tensorF_3))
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

Input In [3], in ()
      3 tensorA_2x3 = torch.arange(12).reshape(4,1,3)
      4 tensorF_3 = torch.tensor([2, 3, 4])
----> 5 print("torch.mv:",torch.mv(tensorA_2x3, tensorF_3))


RuntimeError: vector + matrix @ vector expected, got 1, 3, 1
# mismatched lengths do not work: [3,4] and [3] raises an error
tensorA_2x3 = torch.arange(12).reshape(3,4)
tensorF_3 = torch.tensor([2, 3, 4])
print("torch.mv:",torch.mv(tensorA_2x3, tensorF_3))
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

Input In [14], in ()
      2 tensorA_2x3 = torch.arange(12).reshape(3,4)
      3 tensorF_3 = torch.tensor([2, 3, 4])
----> 4 print("torch.mv:",torch.mv(tensorA_2x3, tensorF_3))


RuntimeError: size mismatch, got 3, 3x4,3
# [4,3] and [3,2] does not work (the second argument must be a vector)
tensorA_2x3 = torch.arange(12).reshape(4,3)
tensorF_3x2 = torch.tensor(
    [[1,2],
     [3,4],
     [5,6]]
)
print("torch.mv:",torch.mv(tensorA_2x3, tensorF_3x2))
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

Input In [13], in ()
      1 tensorA_2x3 = torch.arange(12).reshape(4,3)
      2 tensorF_3x2 = torch.tensor(
      3     [[1,2],
      4      [3,4],
      5      [5,6]]
      6 )
----> 7 print("torch.mv:",torch.mv(tensorA_2x3, tensorF_3x2))


RuntimeError: vector + matrix @ vector expected, got 1, 2, 2

5. torch.matmul(a,b)

5.1. 【1-D × 1-D】

  • The two 1-D tensors must have the same number of elements; this is equivalent to a dot product
a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 3, 4])
torch.matmul(a, b)
tensor(20)
torch.dot(a,b)
tensor(20)

5.2. 【2-D × 2-D】

If A is a 2-D tensor of shape m×n and B is a 2-D tensor of shape n×p, torch.matmul(A, B) returns the ordinary matrix product, just like torch.mm (see the check after the example below).

a = torch.rand(3,5)
b = torch.rand(5,3)
torch.matmul(a,b).shape
torch.Size([3, 3])
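For 2-D inputs, torch.matmul gives the same result as torch.mm (a minimal check):
a = torch.rand(3, 5)
b = torch.rand(5, 3)
print(torch.allclose(torch.matmul(a, b), torch.mm(a, b)))   # True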

5.3. 【1-D × 2-D】

If the first argument is 1-D and the second is 2-D, this is a row vector times a matrix (the usual linear-algebra vector-matrix product): vector × matrix is a linear combination of the matrix's rows (see the sketch after the example below).

a = torch.rand(3)
b = torch.rand(3,5)
torch.matmul(a,b).shape
torch.Size([5])
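Per the torch.matmul documentation, a 1-D first argument is treated as if a 1 were prepended to its shape, and that dimension is removed again after the multiplication. A minimal sketch of that equivalence:
a = torch.rand(3)
b = torch.rand(3, 5)
# (1, 3) @ (3, 5) -> (1, 5), then drop the prepended dimension -> (5,)
expected = torch.matmul(a.unsqueeze(0), b).squeeze(0)
print(torch.allclose(torch.matmul(a, b), expected))   # True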

5.4. 【2-D × 1-D】

If the first argument is 2-D and the second is 1-D, this is a matrix times a column vector: matrix × vector is a linear combination of the matrix's columns (equivalent to torch.mv, as checked after the example below).

a = torch.rand(3,5)
b = torch.rand(5)
torch.matmul(a,b).shape
torch.Size([3])
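For 2-D × 1-D, torch.matmul reduces to the matrix-vector product torch.mv (a minimal check):
a = torch.rand(3, 5)
b = torch.rand(5)
print(torch.allclose(torch.matmul(a, b), torch.mv(a, b)))   # True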

5.5. 【Higher-dimensional × higher-dimensional】

The matrix multiplication is applied to the last two dimensions; the remaining leading dimensions are treated as batch dimensions and are broadcast against each other, so any missing leading dimensions are filled in and the output carries the broadcast batch shape (checked after the second example below).

a = torch.rand(3,2,4,5)
b = torch.rand(3,2,5,4)
torch.matmul(a,b).shape
out:
torch.Size([3, 2, 4, 4])

a = torch.rand(3,2,4,5)
b = torch.rand(2,5,4)
torch.matmul(a,b).shape
out:
torch.Size([3, 2, 4, 4])
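In the second example the (2, 5, 4) tensor is broadcast against the leading batch dimension of the (3, 2, 4, 5) tensor before the matrix multiplication. A minimal check that this matches an explicit expand:
a = torch.rand(3, 2, 4, 5)
b = torch.rand(2, 5, 4)
print(torch.allclose(torch.matmul(a, b),
                     torch.matmul(a, b.expand(3, 2, 5, 4))))   # True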

# batched: row vector × matrix
a = torch.randn(3)
b = torch.randn(5,6,3,4)
torch.matmul(a,b).shape
out:
torch.Size([5, 6, 4])

# batched: matrix × column vector
a = torch.randn(5,6,7,8)
b = torch.randn(8)
torch.matmul(a,b).shape
out:
torch.Size([5, 6, 7])
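The two vector cases above follow the same rule as in 5.3/5.4: the vector is temporarily reshaped into a one-row (or one-column) matrix, the batched matrix multiplication is broadcast over the leading dimensions, and the extra dimension is removed afterwards. A minimal check for the matrix × column-vector case:
a = torch.randn(5, 6, 7, 8)
b = torch.randn(8)
# (..., 7, 8) @ (8, 1) -> (..., 7, 1), then drop the appended dimension -> (..., 7)
expected = torch.matmul(a, b.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(torch.matmul(a, b), expected))   # True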

References:
https://zhuanlan.zhihu.com/p/576893207
https://blog.csdn.net/whitesilence/article/details/119033117
