torch.einsum: a brief walkthrough of how the computation works

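torch.einsum contracts tensors according to a subscript string. In 'fnd,ndh->fh', the indices n and d appear in both inputs but not in the output. That means each output element is o[f, h] = Σ_n Σ_d a[f, n, d] · b[n, d, h], and the result has shape (5, 2).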

>>> a = torch.arange(60.).reshape(5,3,4)
>>> b = torch.arange(24.).reshape(3,4,2)

>>> o = torch.einsum('fnd,ndh->fh', a, b)

>>> o
tensor([[1012., 1078.],
        [2596., 2806.],
        [4180., 4534.],
        [5764., 6262.],
        [7348., 7990.]])
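To make the computation flow explicit, here is a minimal sketch of the same contraction written as plain Python loops, with no PyTorch needed. The function name einsum_fnd_ndh_fh is hypothetical. The nested lists are built to hold the same values as torch.arange(60.).reshape(5,3,4) and torch.arange(24.).reshape(3,4,2). For each output position (f, h), the loops multiply a[f][n][d] by b[n][d][h] and sum over every n and d.

```python
# Hypothetical helper: spells out torch.einsum('fnd,ndh->fh', a, b) as explicit loops.
def einsum_fnd_ndh_fh(a, b):
    F, N, D = len(a), len(a[0]), len(a[0][0])
    H = len(b[0][0])
    out = [[0.0] * H for _ in range(F)]
    for f in range(F):
        for h in range(H):
            # n and d are missing from the output spec '->fh', so they are summed over
            out[f][h] = sum(a[f][n][d] * b[n][d][h]
                            for n in range(N) for d in range(D))
    return out

# Same values as torch.arange(60.).reshape(5, 3, 4): element [f][n][d] = f*12 + n*4 + d
a = [[[float(f * 12 + n * 4 + d) for d in range(4)] for n in range(3)] for f in range(5)]
# Same values as torch.arange(24.).reshape(3, 4, 2): element [n][d][h] = n*8 + d*2 + h
b = [[[float(n * 8 + d * 2 + h) for h in range(2)] for d in range(4)] for n in range(3)]

o = einsum_fnd_ndh_fh(a, b)
for row in o:
    print(row)
```

Each printed row matches the corresponding row of the tensor above. For example, f = 0 gives [1012.0, 1078.0].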

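>>> torch.matmul(a[0,:,:].flatten(), b[:,:,0].flatten())
tensor(1012.)
# equals o[0, 0], the first element of the einsum result:
# the dot product of the flattened a[0] with the flattened b[:, :, 0]

>>> torch.matmul(a[1,:,:].flatten(), b[:,:,0].flatten())
tensor(2596.)
# equals o[1, 0]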
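The two manual checks generalize to the whole result. Because n and d are summed together and sit next to each other in the same order in both inputs, flattening them into a single axis of length 3 × 4 = 12 turns the einsum into an ordinary matrix product. This is a minimal sketch that assumes contiguous tensors laid out exactly as above. The name o_mm is illustrative.

```python
import torch

a = torch.arange(60.).reshape(5, 3, 4)
b = torch.arange(24.).reshape(3, 4, 2)

o = torch.einsum('fnd,ndh->fh', a, b)

# Merge the contracted axes (n, d) into one axis of length 12:
# a: (5, 3, 4) -> (5, 12), b: (3, 4, 2) -> (12, 2)
o_mm = a.reshape(5, 12) @ b.reshape(12, 2)

print(torch.equal(o, o_mm))  # True
```

Row f of a.reshape(5, 12) is a[f].flatten(), and column h of b.reshape(12, 2) is b[:, :, h].flatten(). Each entry of the matrix product is therefore exactly one of the dot products computed by hand above.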
