Differences and usage of torch.mul | torch.mm | torch.bmm | torch.matmul

Contents

    • torch.mul
    • torch.mm
    • torch.bmm
    • torch.matmul
    • Summary

torch.mul

Usage

torch.mul(input1, input2, out=None)  # element-wise multiplication

Functionality
1. When input1 and input2 are matrices/vectors of the same shape, they are multiplied element-wise.
2. When input1 is a matrix/vector and input2 is a scalar, every element of input1 is multiplied by input2.
3. When input1 is a matrix/vector and input2 is a vector/matrix of a different but broadcastable shape, input2 / input1 is broadcast first and then multiplied element-wise.

Examples
#1 input1 and input2 have the same size: element-wise multiplication

a = torch.tensor([[ 1.8351,  2.1536],
                  [-0.8320, -1.4578]])
b = torch.tensor([[2.9355, 0.3450],
                  [2.9355, 0.3450]])
c = torch.mul(a, b)
print(c)
#
tensor([[ 5.3869,  0.7429],
        [-2.4423, -0.5029]])

#3 input1 is a matrix and input2 is a vector: input2 is broadcast first, then multiplied element-wise

a = torch.tensor([[ 1.8351,  2.1536],
                  [-0.8320, -1.4578]])
b = torch.tensor([2.9355, 0.3450])
c = torch.mul(a, b)
print(c)
#
tensor([[ 5.3869,  0.7429],
        [-2.4423, -0.5029]])

As you can see, #1 and #3 give the same result.
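
Case #2 (multiplying by a scalar) works the same way; a minimal illustrative sketch with made-up values:

#2 input1 is a matrix and input2 is a scalar: every element of input1 is multiplied by input2

import torch

a = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
c = torch.mul(a, 2.0)
print(c)
#
tensor([[2., 4.],
        [6., 8.]])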

torch.mm

Usage

torch.mm(input1, input2, out=None)  # 2-D matrix multiplication

Functionality
Performs matrix multiplication of two 2-D matrices: (a, b) × (b, d) = (a, d). It can only handle 2-D matrices; for other dimensionalities use torch.matmul.
Examples

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
mat3 = torch.mm(mat1, mat2)
print(mat3)
#
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])
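
To illustrate the 2-D restriction, the sketch below passes batched tensors to torch.mm and to torch.matmul (the exact error message may vary by PyTorch version):

import torch

a = torch.randn(6, 2, 3)
b = torch.randn(6, 3, 4)
try:
    torch.mm(a, b)                    # torch.mm only accepts 2-D inputs
except RuntimeError as err:
    print("torch.mm rejects 3-D inputs:", err)
print(torch.matmul(a, b).size())      # torch.Size([6, 2, 4])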

torch.bmm

Usage

torch.bmm(mat1, mat2, out=None)  # 3-D tensors: the first dimension is the batch_size, matrix multiplication is applied to the last two dimensions

Functionality
As the name suggests, this is torch.mm with batch computation added. It does not support broadcasting.
Examples

mat1 = torch.randn(6, 2, 3)  # batch_size = 6
mat2 = torch.randn(6, 3, 4)
mat3 = torch.bmm(mat1, mat2)  # torch.Size([6, 2, 4])
print(mat3)
#
tensor([[[-5.5404e-02, -1.2719e+00, -1.3952e+00,  7.2475e-01],
         [ 1.0943e+00,  2.1826e+00, -4.4239e-01, -1.0643e+00]],

        [[ 1.1785e+00, -4.9125e-01, -3.4894e-01, -2.1170e-02],
         [-6.4008e-01, -2.4427e-03, -3.1276e-01, -4.5647e-01]],

        [[-2.9938e-01,  7.6840e-01, -2.7852e-01,  5.4946e-01],
         [ 4.2854e-01,  1.8301e+00,  1.7477e-02, -1.4107e+00]],

        [[-2.7399e-01,  1.2810e+00,  1.8456e+00, -5.5862e-01],
         [ 1.0337e+00,  1.3213e+00,  7.3194e-01,  3.9463e-01]],

        [[-1.3685e-01, -9.7863e-02, -3.3586e-01,  1.9415e-01],
         [-3.7319e+00, -1.0287e+00, -2.8267e+00,  1.6140e+00]],

        [[-2.6132e+00,  1.2601e+00,  2.4735e+00, -5.1219e-01],
         [-3.9365e+00,  1.1015e+00,  5.8874e-01,  3.0009e-01]]])
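
One way to see what bmm computes (an illustrative sketch, not the internal implementation) is to compare it with a Python loop of torch.mm over the batch dimension:

import torch

mat1 = torch.randn(6, 2, 3)
mat2 = torch.randn(6, 3, 4)
out = torch.bmm(mat1, mat2)

# applying torch.mm to each batch slice should give the same result
looped = torch.stack([torch.mm(mat1[i], mat2[i]) for i in range(mat1.size(0))])
print(torch.allclose(out, looped))  # True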

torch.matmul

Usage

torch.matmul(input1, input2, out=None)  # the most general: handles batching, broadcasting, etc.

Functionality
1. The most versatile of the four: matrix multiplication with batching and broadcasting.
2. If input1 is 1-D and input2 is 2-D, a dimension is prepended to input1 (equivalent to input1.unsqueeze(0)) for the vector-matrix multiplication, and removed from the result afterwards.
3. With batched inputs, the batch dimension is preserved in the computation.
4. When the inputs have different numbers of dimensions, they are broadcast first and then multiplied batch-wise.
Examples
1. vector x vector

a = torch.randn(3)
b = torch.randn(3)
c = torch.matmul(a, b)
print(c, c.size())
#
tensor(1.2123) torch.Size([])

2. matrix x vector

a = torch.randn(3, 4)
b = torch.randn(4)
c = torch.matmul(a, b)
print(c.size())
#
torch.Size([3])
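
For the reverse order, vector x matrix (point 2 in the functionality list above), a minimal sketch:

import torch

a = torch.randn(4)
b = torch.randn(4, 3)
c = torch.matmul(a, b)  # a is treated as (1, 4); the prepended dimension is removed from the result
print(c.size())
#
torch.Size([3])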

3. 3-D matrix x 2-D matrix / vector (broadcast)

a = torch.randn(10, 3, 4)
b = torch.randn(4, 6)
c = torch.randn(4)

d = torch.matmul(a, b)
e = torch.matmul(a, c)
print(d.size(), e.size())
#
torch.Size([10, 3, 6]) torch.Size([10, 3])

4. 3-D matrix x 3-D matrix
In this case it is equivalent to torch.bmm.

a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 6)
c = torch.matmul(a, b)
d = torch.bmm(a, b)
print(c == d)
print(c.size())
#
tensor([[[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]],

        [[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
torch.Size([2, 3, 6])
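
Point 4 above (broadcasting before the batched multiplication) also applies when the batch dimensions themselves differ; a small sketch, with shapes chosen only for illustration:

import torch

a = torch.randn(2, 5, 3, 4)   # batch dims (2, 5)
b = torch.randn(5, 4, 6)      # batch dim (5), broadcast to (2, 5)
c = torch.matmul(a, b)
print(c.size())
#
torch.Size([2, 5, 3, 6])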

Summary

Element-wise multiplication: torch.mul

2-D matrix multiplication: torch.mm

Batched 3-D matrix multiplication: torch.bmm

Batched / broadcast matrix multiplication: torch.matmul
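
As a compact side-by-side check (an illustrative sketch; the tensor shapes are chosen arbitrarily):

import torch

a2 = torch.randn(2, 3)
b2 = torch.randn(3, 4)
a3 = torch.randn(5, 2, 3)
b3 = torch.randn(5, 3, 4)

print(torch.mul(a2, 2.0).size())    # torch.Size([2, 3])    element-wise
print(torch.mm(a2, b2).size())      # torch.Size([2, 4])    2-D only
print(torch.bmm(a3, b3).size())     # torch.Size([5, 2, 4]) batched, no broadcasting
print(torch.matmul(a3, b2).size())  # torch.Size([5, 2, 4]) batched + broadcasting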
