关于argmax()

input.argmax(1):横向比较
input.argmax(0):纵向比较

import torch

input = torch.tensor([[0.6,0.5,0.3],
                      [0.2,0.3,0.1]])

target1 = torch.tensor([0,1])
target2 = torch.tensor([0,0,1])

# 参数为1,横向比较
output1 = input.argmax(1)
print(output1)      # torch.tensor([0,1])
print((output1 == target1).sum())   # 输出预测正确次数:tensor(2)

# 参数为0,纵向比较
output2 = input.argmax(0)
print(output2)      # torch.tensor([0,0,0])
print((output2 == target2).sum())   # 输出预测正确次数:tensor(2)

你可能感兴趣的:(pytorch学习,深度学习,python,pytorch)