pytorch中利用gather函数取出标签的预测概率

在pytorch中神经网络训练输出为one-hot编码,假设y_h为2个样本,在3个类别的情况下的输出:

y_h= torch.tensor([[0.1, 0.2, 0.7], [0.2, 0.2, 0.6]])

y_lable=torch.LongTensor([0, 2])

out=y_h.gather(1, y_lable.view(-1, 1))

输出结果   tensor([[0.1000], [0.6000]])

你可能感兴趣的:(pytorch,深度学习,机器学习)