torch.gather的使用及理解

结论:使用方法

# gather,沿dim指定的轴收集值。
y_hat.gather(1, y.view(-1, 1))# y.view(-1, 1)会变成一列,y_hat的取y作为的索引的值

分步理解:先创建一个2*3的tensor

>>y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])

tensor([[0.1000, 0.3000, 0.6000],
        [0.3000, 0.2000, 0.5000]])

为了使用gather函数,我们得创建一个tensor作为gather得参数

>>y = torch.LongTensor([0, 2])

tensor([0, 2])

我们需要把y变个形状

>>y.view(-1, 1)

tensor([[0],
        [2]])

先来看看使用得结果

>>y_hat.gather(1, y.view(-1, 1))
tensor([[0.1000],
        [0.5000]])

图解:

torch.gather的使用及理解_第1张图片

 

 

 

 

你可能感兴趣的:(深度学习,神经网络,pytorch,python)