torch.mul、torch.mm、torch.bmm、torch.matmul的区别

torch.mul

torch.mul(input, other, out=None)

功能

对位相乘,可以广播

该函数能处理两种情况

  1. input是矩阵/向量,other是标量
    这个时候是就是input的所有元素乘上other
  2. input是矩阵/向量,other是矩阵/向量
    这时 o u t i = i n p u t i × o t h e r i out_i = input_i \times other_i outi=inputi×otheri,对位相乘,如果两个都是向量,则可以广播的

例子

  1. input和other的size相同的对位相乘

    a: tensor([[ 1.8351,  2.1536],
        [-0.8320, -1.4578]])
    b: tensor([[2.9355, 0.3450],
        [0.5708, 1.9957]])
    c = torch.mul(a,b):
     tensor([[ 5.3869,  0.7429],
        [-0.4749, -2.9093]])
    
  2. 两个向量的广播

    a: tensor([[ 1.8351,  2.1536],
            [-0.8320, -1.4578]])
    b: tensor([[2.9355, 0.3450],
            [0.5708, 1.9957]])
    c = torch.mul(a,b):
     tensor([[ 5.3869,  0.7429],
            [-0.4749, -2.9093]])
    

torch.mm

torch.mm(input, mat2, out=None)

解决的问题

处理二维矩阵的乘法,而且也只能处理二维矩阵,其他维度要用torch.matmul

例子

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

torch.bmm

torch.bmm(input, mat2, out=None)

看函数名就知道,在torch.mm的基础上加了个batch计算,不能广播


torch.matmul

torch.matmul(input, other, out=None)

功能
适用性最多的,能处理batch、广播的矩阵:

  1. 如果第一个参数是一维,第二个是二维,那么给第一个提供一个维度
  2. 如果第一个是二维,第二个是一维,就是矩阵乘向量
  3. 带有batch的情况,可保留batch计算
  4. 维度不同时,可先广播,再batch计算

例子

  1. vector x vector

    tensor1 = torch.randn(3)
    tensor2 = torch.randn(3)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([])
    
  2. matrix x vector

    tensor1 = torch.randn(3, 4)
    tensor2 = torch.randn(4)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([3])
    
  3. batched matrix x broadcasted vecto

    tensor1 = torch.randn(10, 3, 4)
    tensor2 = torch.randn(4)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([10, 3])
    
  4. batched matrix x batched matrix

    tensor1 = torch.randn(10, 3, 4)
    tensor2 = torch.randn(10, 4, 5)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([10, 3, 5])
    

总结

对位相乘用torch.mul,二维矩阵乘法用torch.mm,batch二维矩阵用torch.bmm,batch、广播用torch.matmul

你可能感兴趣的:(Pytorch)