全称为matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是2维;
torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的矩阵
** mat2* (Tensor) – – 第二个要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)
tensor([[ 0.4851, 0.5037, -0.3633],
[-0.0760, -3.6705, 2.4784]])
全称为batch matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是3维;
torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一批要相乘的矩阵
** mat2* (Tensor) – – 第二批要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。
input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size()
torch.Size([10, 3, 5])
可进行多维矩阵运算,根据不同输入维度进行广播机制然后运算,和点积类似,广播机制可参考之前博文torch.mul()函数。
torch.matmul(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的张量
** mat2* (Tensor) – – 第二个要相乘的张量
支持广播到通用形状、类型推广以及整数、浮点和复杂输入。
(1)若两个都是1D(向量)的,则返回两个向量的点积;
(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D;
(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系;
(4)若input是2D,other是1D,则返回两者的点积结果;
(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)
matmul() 根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。
# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()
torch.Size([])
# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([3])
# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
直接看一个4维的二值例子,先看图(红虚线和实线是为了便于区分维度而添加),不懂再结合代码和结果分析,先做广播,然后对应矩阵进行乘积运算。
代码如下:
import torch
import numpy as np
np.random.seed(2022)
a = np.random.randint(low=0, high=2, size=(2, 2, 3, 4))
a = torch.tensor(a)
b = np.random.randint(low=0, high=2, size=(2, 1, 4, 3))
b = torch.tensor(b)
c = torch.matmul(a, b)
# or
# c = a @ b
print(a)
print("=============================================")
print(b)
print("=============================================")
print(c.size())
print("=============================================")
print(c)
运行结果为:
tensor([[[[1, 0, 1, 0],
[1, 1, 0, 1],
[0, 0, 0, 0]],
[[1, 1, 1, 1],
[1, 1, 0, 0],
[0, 1, 0, 1]]],
[[[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 1, 0, 0]],
[[1, 1, 1, 1],
[1, 1, 1, 1],
[0, 0, 0, 0]]]], dtype=torch.int32)
=============================================
tensor([[[[0, 1, 0],
[1, 1, 0],
[0, 0, 0],
[1, 1, 0]]],
[[[0, 1, 0],
[1, 1, 1],
[1, 1, 1],
[1, 0, 1]]]], dtype=torch.int32)
=============================================
torch.Size([2, 2, 3, 3])
=============================================
tensor([[[[0, 1, 0],
[2, 3, 0],
[0, 0, 0]],
[[2, 3, 0],
[1, 2, 0],
[2, 2, 0]]],
[[[1, 0, 1],
[1, 0, 1],
[1, 1, 1]],
[[3, 3, 3],
[3, 3, 3],
[0, 0, 0]]]], dtype=torch.int32)
部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
参考博文1 官方文档查询地址
https://pytorch.org/docs/stable/index.html
参考博文2 Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别
https://blog.csdn.net/irober/article/details/113686080