torch.einsum()用法举例

import torch

# Fix the RNG seed so the sample tensors are repeatable. The values recorded in
# the strings below come from the original author's run; PyTorch does not
# guarantee identical random streams across releases/devices, so yours may differ.
torch.manual_seed(2)
x = torch.randn(1, 2, 3, 4)  # indexed as (n, c, v, l) in the einsum below
A = torch.randn(3, 2)        # indexed as (v, w) in the einsum below
print("x:", x, '\n', "A.T:", A.T)
'''
x: tensor([[[[-1.0408,  0.9166, -1.3042, -1.1097],
             [-1.2188,  1.1676, -1.0574, -0.1188],
             [-0.8110,  0.6737, -1.1233, -0.0919]],

            [[-0.1320, -0.2751, -0.2350,  0.0937],
             [-0.7396, -1.2425, -0.1752,  0.6990],
             [-0.6861,  0.7202,  0.1963,  0.6142]]]])
 A.T: tensor([[-0.0591,  0.4258, -0.4766],
              [-1.5653, -1.4818,  0.2480]])
'''
# Contract over the shared index v (it appears in both inputs but not in the
# output):  out[n, c, w, l] = sum_v x[n, c, v, l] * A[v, w]
# -> (1, 2, 3, 4) x (3, 2) gives shape (1, 2, 2, 4).
# Keep a reference to the input first: ``x`` is rebound to the result here,
# and the element-wise checks further down need the ORIGINAL input values.
x_in = x
x = torch.einsum('ncvl,vw->ncwl', x_in, A)
print('ncvl,vw->ncwl:\n', x)
'''
 tensor([[[[-7.0922e-02,  1.2192e-01,  1.6221e-01,  5.8812e-02],
           [ 3.2339e+00, -2.9978e+00,  3.3296e+00,  1.8902e+00]],

          [[ 1.9895e-02, -8.5600e-01, -1.5428e-01, -6.1229e-04],
           [ 1.1324e+00,  2.4502e+00,  6.7620e-01, -1.0301e+00]]]])
'''
# Every output element is a dot product over v. A.T[0, :] is column 0 of A.
# x[0, 0, 0, 0] = sum_v x_in[0, 0, v, 0] * A[v, 0]:
# tensor(-7.0922e-02), (-1.0408)* (-0.0591) + (-1.2188) * 0.4258 + (-0.8110) * (-0.4766) = -0.07093116000000005
print(torch.einsum('v,v->', A.T[0, :], x_in[0, 0, :, 0]))
# x[0, 0, 0, 1] = sum_v x_in[0, 0, v, 1] * A[v, 0]:
# tensor(1.2192e-01), 0.9166* (-0.0591) + 1.1676 * 0.4258 + 0.6737 * (-0.4766) = 0.1219076
print(torch.einsum('v,v->', A.T[0, :], x_in[0, 0, :, 1]))

参考:Rachel~Liu

你可能感兴趣的:(深度学习)