pytorch:torch.mm()和torch.matmul()

torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作

torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(m, x),b的维度是(x, n),返回的就是(m, n)的矩阵

相同:都可以来做矩阵相乘:

    a = torch.randn(2, 3)
    b = torch.randn(3, 2)
    print(torch.mm(a, b))
    print(torch.matmul(a, b))

区别:

matmul支持向量相乘,mm不支持。

    import torch

    x = torch.rand(2)
    y = torch.rand(2)
    print(torch.matmul(x, y))
    print(torch.mm(x, y)) #报错

你可能感兴趣的:(python基础)