The difference between torch.matmul, torch.mm, and * in PyTorch

a * b is element-wise multiplication: corresponding elements of the two tensors are multiplied, and the output has the same shape as the inputs. In the simplest case the two tensors have exactly the same shape (strictly speaking, * also broadcasts compatible shapes, but the example below uses matching shapes):

import torch
a = torch.randn(2,3)
b = torch.randn(2,3)
print(a)
print(b)
c = a * b
print(c.size())
print(c)

Output:

tensor([[-0.6309, -1.4365, -0.8549],
        [ 1.1856,  0.8986,  1.6343]])
tensor([[-0.2452, -1.0548,  0.9490],
        [-1.7368,  1.2816, -0.2548]])
torch.Size([2, 3])
tensor([[ 0.1547,  1.5153, -0.8113],
        [-2.0591,  1.1516, -0.4163]])

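As a side note, here is a minimal sketch (independent of the random values above) confirming that the * operator is the same element-wise product as torch.mul:

import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)

# '*' and torch.mul compute the same element-wise product
assert torch.equal(a * b, torch.mul(a, b))

# each output element is the product of the corresponding input elements
assert (a * b)[0, 1] == a[0, 1] * b[0, 1]
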
torch.mm(a,b) requires the two matrices to have shapes (n×m) and (m×p); it is ordinary 2-D matrix multiplication:

import torch
a = torch.randn(2,3)
b = torch.randn(3,2)
print(a)
print(b)

c = torch.mm(a,b)
print(c.size())
print(c)

Output:

tensor([[ 1.0326, -0.8577,  0.9282],
        [-0.1615,  2.2162,  1.1355]])
tensor([[-0.2482, -0.7289],
        [-1.4625,  0.9330],
        [-0.6411, -1.5660]])
torch.Size([2, 2])
tensor([[ 0.4031, -3.0065],
        [-3.9292,  0.4072]])

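Before moving on to matmul, a quick sketch (a and b here are just fresh random tensors with the shapes above) of the two constraints just mentioned: the inner dimensions must match, and torch.mm only accepts 2-D inputs:

import torch

a = torch.randn(2, 3)
b = torch.randn(3, 2)

# (2, 3) @ (3, 2) -> (2, 2): the inner dimensions (3 and 3) must match
c = torch.mm(a, b)
print(c.size())  # torch.Size([2, 2])

# torch.mm only handles 2-D matrices; batched (3-D) inputs raise an error
try:
    torch.mm(torch.randn(5, 2, 3), torch.randn(5, 3, 2))
except RuntimeError as err:
    print("torch.mm rejects 3-D inputs:", err)
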
torch.matmul(a,b): matmul performs general tensor multiplication, so the inputs may be higher-dimensional. When an input has an extra leading dimension, it is treated as a batch dimension and ordinary matrix multiplication is applied to the remaining two dimensions:

import torch
a = torch.randn(5,3,4)
b = torch.randn(4,2)
print(a)
print(b)

c = torch.matmul(a,b)
print(c.size())
print(c)

Output:

tensor([[[-0.6082, -0.4990,  0.5180,  0.4803],
         [ 0.4278, -1.4530,  0.7533,  2.0362],
         [-0.7447, -0.4938,  0.9194,  0.7219]],

        [[-0.9034,  3.1070, -0.0789, -0.1922],
         [-0.7752, -0.0700, -1.2752, -0.1430],
         [-0.1135,  0.4542,  1.5085, -0.8310]],

        [[-0.4306, -0.7926,  0.0455, -1.4027],
         [ 1.8869, -1.3122, -0.1038,  1.8231],
         [-0.6760, -0.4332, -0.9828, -1.0844]],

        [[-0.9391,  0.0992, -0.3816,  1.7813],
         [ 0.4111, -1.3645, -1.2178,  0.9432],
         [ 1.4839,  0.6080, -0.3512, -0.9438]],

        [[-0.8659, -0.4728, -0.6430,  0.7306],
         [ 2.1543, -0.0352,  0.3309, -0.8087],
         [ 2.3823,  0.4962, -0.0237, -0.0065]]])
tensor([[ 0.2750,  0.5015],
        [-0.5091, -0.2192],
        [-0.1881, -0.3595],
        [-1.0113, -0.1662]])
torch.Size([5, 3, 2])
tensor([[[-0.4963, -0.4617],
         [-1.3434, -0.0761],
         [-0.8563, -0.7157]],

        [[-1.6210, -1.0739],
         [ 0.2070,  0.1088],
         [ 0.2941, -0.5607]],

        [[ 1.6951,  0.1746],
         [-0.6372,  0.9683],
         [ 1.3162,  0.2895]],

        [[-2.0383, -0.6516],
         [ 0.0830,  0.7864],
         [ 1.1189,  0.8940]],

        [[-0.6152, -0.2209],
         [ 1.3658,  1.1036],
         [ 0.4134,  1.0956]]])
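
To make the batch behavior concrete, here is a small sketch (using the same shapes as above) checking that each batch slice of the matmul result is just an ordinary torch.mm product:

import torch

a = torch.randn(5, 3, 4)
b = torch.randn(4, 2)

# matmul treats the leading dimension of a as a batch dimension: result shape (5, 3, 2)
c = torch.matmul(a, b)

# each batch slice equals a plain 2-D matrix product with b
for i in range(a.size(0)):
    assert torch.allclose(c[i], torch.mm(a[i], b))
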
matmul also broadcasts the batch dimensions: when the batch sizes differ but are broadcastable, the smaller one is expanded. In the example below, a has batch size 2 and b has batch size 1, so b is reused for every batch of a:

import torch
a = torch.randn(2,5,3)
b = torch.randn(1,3,4)
print(a)
print(b)

c = torch.matmul(a,b)
print(c.size())
print(c)

Output:

tensor([[[ 1.3973, -0.6357, -1.1457],
         [-1.0271, -0.0732,  0.8141],
         [ 0.0373,  0.2833,  0.1024],
         [-0.0090, -1.4797,  0.1669],
         [ 1.2712, -0.0930,  1.0184]],

        [[-2.2818, -0.6199,  0.5868],
         [ 0.8801,  0.7427,  1.1135],
         [ 0.7952, -0.1778,  1.2183],
         [ 0.8849, -0.2708, -0.2155],
         [-0.1221,  0.5816,  3.0285]]])
tensor([[[ 0.0745, -0.6830, -0.6044,  0.0647],
         [-1.5686,  0.3620,  0.8491,  1.0607],
         [ 1.3519,  0.7895,  1.4727,  0.8019]]])
torch.Size([2, 5, 4])
tensor([[[-0.4477, -2.0890, -3.0716, -1.5026],
         [ 1.1389,  1.3178,  1.7576,  0.5088],
         [-0.3032,  0.1579,  0.3689,  0.3851],
         [ 2.5459, -0.3977, -1.0052, -1.4362],
         [ 1.6174, -0.0979,  0.6526,  0.8003]],

        [[ 1.5957,  1.7975,  1.7169, -0.3346],
         [ 0.4058,  0.5468,  1.7386,  1.7377],
         [ 1.9851,  0.3543,  1.1625,  0.8398],
         [ 0.1993, -0.8726, -1.0822, -0.4028],
         [ 3.1728,  2.6848,  5.0278,  3.0376]]])

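As a quick check of the broadcasting above (again a sketch with the same shapes), every slice of the result equals the corresponding slice of a multiplied by the single matrix in b:

import torch

a = torch.randn(2, 5, 3)
b = torch.randn(1, 3, 4)

# b's batch dimension of size 1 is broadcast across a's batch dimension of size 2
c = torch.matmul(a, b)   # shape (2, 5, 4)

for i in range(a.size(0)):
    assert torch.allclose(c[i], torch.mm(a[i], b[0]))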
 

Summary:

1. a * b requires the two inputs to have matching (or broadcastable) shapes; corresponding elements are multiplied.

2. For 2-D inputs, torch.mm(a,b) and torch.matmul(a,b) give the same result (see the sketch after this list).

3. torch.matmul(a,b) can multiply higher-dimensional tensors; the extra leading dimensions are treated as batch dimensions and matrix multiplication is applied to the last two dimensions.

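The following sketch (plain PyTorch, nothing beyond what is shown above) ties the three points together:

import torch

a = torch.randn(2, 3)
b = torch.randn(3, 2)
d = torch.randn(2, 3)

# point 1: '*' is element-wise and needs matching shapes; the output shape is unchanged
assert (a * d).shape == (2, 3)

# point 2: for 2-D inputs, torch.mm, torch.matmul and the @ operator agree
assert torch.allclose(torch.mm(a, b), torch.matmul(a, b))
assert torch.equal(torch.matmul(a, b), a @ b)

# point 3: torch.matmul also accepts batched inputs (torch.mm does not)
batched = torch.matmul(torch.randn(5, 2, 3), torch.randn(5, 3, 4))
assert batched.shape == (5, 2, 4)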