Usage
torch.mul(input1, input2, out=None)  # element-wise multiplication
Function
1. If input1 and input2 are both matrices/vectors of the same shape, they are multiplied element-wise.
2. If input1 is a matrix/vector and input2 is a scalar, every element of input1 is multiplied by input2.
3. If input1 is a matrix/vector and input2 is a vector/matrix of a different (but compatible) shape, input2 / input1 is first broadcast, then the multiplication proceeds element-wise.
Examples
#1 input1 and input2 have the same size; element-wise multiplication
a = torch.tensor([[ 1.8351,  2.1536],
                  [-0.8320, -1.4578]])
b = torch.tensor([[2.9355, 0.3450],
                  [2.9355, 0.3450]])
c = torch.mul(a, b)
#
tensor([[ 5.3869,  0.7429],
        [-2.4423, -0.5029]])
#3 input1 is a matrix and input2 is a vector; input2 is broadcast first, then multiplied element-wise
a = torch.tensor([[ 1.8351,  2.1536],
                  [-0.8320, -1.4578]])
b = torch.tensor([2.9355, 0.3450])
c = torch.mul(a, b)
#
tensor([[ 5.3869,  0.7429],
        [-2.4423, -0.5029]])
As can be seen, #1 and #3 give the same result.
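Case 2 above (tensor × scalar) has no example of its own; a minimal sketch:

```python
import torch

# Case 2: input1 is a tensor, input2 is a Python scalar --
# every element of input1 is scaled by input2.
a = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
b = torch.mul(a, 10)  # equivalent to a * 10
print(b)
# tensor([[10., 20.],
#         [30., 40.]])
```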
Usage
torch.mm(input1, input2, out=None)  # 2-D matrix multiplication
Function
Matrix multiplication of two 2-D matrices: (a, b) × (b, d) = (a, d). It handles only 2-D matrices; for other dimensionalities use torch.matmul.
Examples
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
mat3 = torch.mm(mat1, mat2)
#
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])
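Passing anything other than 2-D tensors makes torch.mm raise an error; a quick check (the exact error text varies across PyTorch versions):

```python
import torch

a = torch.randn(2, 2, 3)  # 3-D: has a batch dimension
b = torch.randn(2, 3, 4)
try:
    torch.mm(a, b)        # mm accepts only 2-D matrices
except RuntimeError as err:
    print("mm failed:", err)  # use torch.bmm or torch.matmul instead
```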
Usage
torch.bmm(mat1, mat2, out=None)  # 3-D tensors: the first dimension is batch_size, the last two dimensions are matrix-multiplied
Function
As the name suggests, this is torch.mm with batch computation added; it does not broadcast.
Examples
mat1 = torch.randn(6, 2, 3)  # batch_size = 6
mat2 = torch.randn(6, 3, 4)
mat3 = torch.bmm(mat1, mat2)  # torch.Size([6, 2, 4])
#
tensor([[[-5.5404e-02, -1.2719e+00, -1.3952e+00,  7.2475e-01],
         [ 1.0943e+00,  2.1826e+00, -4.4239e-01, -1.0643e+00]],

        [[ 1.1785e+00, -4.9125e-01, -3.4894e-01, -2.1170e-02],
         [-6.4008e-01, -2.4427e-03, -3.1276e-01, -4.5647e-01]],

        [[-2.9938e-01,  7.6840e-01, -2.7852e-01,  5.4946e-01],
         [ 4.2854e-01,  1.8301e+00,  1.7477e-02, -1.4107e+00]],

        [[-2.7399e-01,  1.2810e+00,  1.8456e+00, -5.5862e-01],
         [ 1.0337e+00,  1.3213e+00,  7.3194e-01,  3.9463e-01]],

        [[-1.3685e-01, -9.7863e-02, -3.3586e-01,  1.9415e-01],
         [-3.7319e+00, -1.0287e+00, -2.8267e+00,  1.6140e+00]],

        [[-2.6132e+00,  1.2601e+00,  2.4735e+00, -5.1219e-01],
         [-3.9365e+00,  1.1015e+00,  5.8874e-01,  3.0009e-01]]])
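Because bmm does not broadcast, the two batch sizes must match exactly; torch.matmul, by contrast, does broadcast the batch dimension. A sketch:

```python
import torch

a = torch.randn(6, 2, 3)
b = torch.randn(1, 3, 4)  # batch size 1 is NOT broadcast by bmm
try:
    torch.bmm(a, b)       # fails: batch sizes 6 and 1 differ
except RuntimeError as err:
    print("bmm failed:", err)

c = torch.matmul(a, b)    # matmul broadcasts the batch dimension
print(c.size())           # torch.Size([6, 2, 4])
```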
Usage
torch.matmul(input1, input2, out=None)  # the most versatile: handles batches, broadcasting, etc.
Function
1. The most versatile matrix multiplication; handles batched and broadcast computation.
2. If input1 is 1-D and input2 is 2-D, a dimension is prepended to input1 (equivalent to input1.unsqueeze(0)) before the vector-matrix product, and removed again afterwards.
3. With batched inputs, the batch dimension is preserved during computation.
4. When the inputs' dimensionalities differ, they are first broadcast, then batch-multiplied.
Examples
1. vector x vector
a = torch.randn(3)
b = torch.randn(3)
c = torch.matmul(a, b)
print(c, c.size())
#
tensor(1.2123) torch.Size([])
2. matrix x vector
a = torch.randn(3, 4)
b = torch.randn(4)
c = torch.matmul(a, b)
print(c.size())
#
torch.Size([3])
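Rule 2 above (a 1-D input1 with a 2-D input2) does not appear in the numbered examples; a minimal sketch:

```python
import torch

# vector x matrix: a is treated as shape (1, 4) for the product
# (as if a.unsqueeze(0)), and the extra dimension is removed afterwards.
a = torch.randn(4)
b = torch.randn(4, 3)
c = torch.matmul(a, b)
print(c.size())  # torch.Size([3])
```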
3. 3-D matrix x 2-D matrix / vector (broadcast)
a = torch.randn(10, 3, 4)
b = torch.randn(4, 6)
c = torch.randn(4)
d = torch.matmul(a, b)
e = torch.matmul(a, c)
print(d.size(), e.size())
#
torch.Size([10, 3, 6]) torch.Size([10, 3])
4. 3-D matrix x 3-D matrix
In this case it is equivalent to torch.bmm.
a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 6)
c = torch.matmul(a, b)
d = torch.bmm(a, b)
print(c == d)
print(c.size())
#
tensor([[[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]],

        [[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
torch.Size([2, 3, 6])
Use torch.mul for element-wise multiplication.
Use torch.mm for 2-D matrix multiplication.
Use torch.bmm for batched 3-D matrix multiplication.
Use torch.matmul for batched / broadcast matrix multiplication.
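The four rules above, side by side (shapes only, a minimal sketch):

```python
import torch

a  = torch.randn(2, 3);    b  = torch.randn(2, 3)     # same shape
m1 = torch.randn(2, 3);    m2 = torch.randn(3, 4)     # 2-D matrices
t1 = torch.randn(5, 2, 3); t2 = torch.randn(5, 3, 4)  # batched 3-D

print(torch.mul(a, b).size())       # torch.Size([2, 3])    element-wise
print(torch.mm(m1, m2).size())      # torch.Size([2, 4])    2-D only
print(torch.bmm(t1, t2).size())     # torch.Size([5, 2, 4]) batched
print(torch.matmul(t1, m2).size())  # torch.Size([5, 2, 4]) broadcasts m2
```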