tensor.max方法

pred = torch.tensor([[0.7,0,0.2,0.1,0],[0,0.2,0.4,0.3,0.1]])
pred.max(1, keepdim=True)[1]

输出:

 

pred.max(1, keepdim=False)[1]

输出:

可以用在softmax后取出最大概率所在的索引

tensor.max(1,keepdim)[1]:返回索引

tensor.max(1,keepdim)[0]:返回最大值

1指第一维度,第0维一般为batchsize

你可能感兴趣的:(pytorch,pytorch,深度学习,机器学习)