【torch.max()函数】predic = torch.max(outputs.data, 1)[1].cpu().numpy()

torch.max(input, dim, keepdim=False, out=None)

按维度dim 返回最大值以及最大值的索引。

dim = 0 表示按列求最大值,并返回其引
dim = 1 表示按行求最大值,并返回其索引

_, predicted = torch.max(outputs.data, 1)

torch.max()函数返回两个值,一个是具体的值,也就是预测概率,另一个是值对应的索引,即预测类别;这两个值分别用_,predidcted表示。

predic = torch.max(outputs.data, 1)[1].cpu().numpy()

troch.max()[1]:只返回最大值的索引

.numpy() :把数据转化为ndarray,即N维数组对象

你可能感兴趣的:(pytorch)