import torch
# Demonstrates torch.einsum('ncvl,vw->ncwl', x, A): the third axis (v) of the
# 4-D tensor x is contracted against the first axis of the matrix A, i.e.
#     out[n, c, w, l] = sum_v x[n, c, v, l] * A[v, w]
# Reference: Rachel~Liu
torch.manual_seed(2)
x = torch.randn(1, 2, 3, 4)  # axes (n, c, v, l) = (1, 2, 3, 4)
A = torch.randn(3, 2)        # axes (v, w) = (3, 2)
print("x:", x, '\n', "A.T:", A.T)
# Recorded output (seed 2):
# x: tensor([[[[-1.0408, 0.9166, -1.3042, -1.1097],
# [-1.2188, 1.1676, -1.0574, -0.1188],
# [-0.8110, 0.6737, -1.1233, -0.0919]],
# [[-0.1320, -0.2751, -0.2350, 0.0937],
# [-0.7396, -1.2425, -0.1752, 0.6990],
# [-0.6861, 0.7202, 0.1963, 0.6142]]]])
# A.T: tensor([[-0.0591, 0.4258, -0.4766],
# [-1.5653, -1.4818, 0.2480]])

# Store the result separately: the manual checks below need the ORIGINAL x.
# (Overwriting x here would leave only 2 entries along that axis, which cannot
# broadcast against the 3 entries of A.T[0, :].)
out = torch.einsum('ncvl,vw->ncwl', x, A)  # axes (n, c, w, l) = (1, 2, 2, 4)
print('ncvl,vw->ncwl:\n', out)
# Recorded output (seed 2):
# tensor([[[[-7.0922e-02, 1.2192e-01, 1.6221e-01, 5.8812e-02],
# [ 3.2339e+00, -2.9978e+00, 3.3296e+00, 1.8902e+00]],
# [[ 1.9895e-02, -8.5600e-01, -1.5428e-01, -6.1229e-04],
# [ 1.1324e+00, 2.4502e+00, 6.7620e-01, -1.0301e+00]]]])

# out[0, 0, 0, 0] = tensor(-7.0922e-02) is column 0 of A dotted with x[0, 0, :, 0]:
# (-1.0408)*(-0.0591) + (-1.2188)*0.4258 + (-0.8110)*(-0.4766) = -0.07093116000000005
print(torch.sum(A.T[0, :] * x[0, 0, :, 0]))
# out[0, 0, 0, 1] = tensor(1.2192e-01) is column 0 of A dotted with x[0, 0, :, 1]:
# 0.9166*(-0.0591) + 1.1676*0.4258 + 0.6737*(-0.4766) = 0.1219076
print(torch.sum(A.T[0, :] * x[0, 0, :, 1]))
# The full contraction equals the broadcast matrix product A.T @ x:
# (w, v) @ (n, c, v, l) -> (n, c, w, l).
print(torch.allclose(out, torch.matmul(A.T, x)))