_, preds = torch.max(outputs.data, 1)

今天在看《PyTorch深度学习》这本书的时候,看到了一段代码,怎么都看不懂,然后CSDN上搜索了一下,发现了大佬的以篇博客《PyTorch系列 | _, predicted = torch.max(outputs.data, 1)的理解》,这里记录一下。

_, preds = torch.max(outputs.data, 1)

源代码如下:

# forward
 outputs = model(inputs)
 _, preds = torch.max(outputs.data, 1)
loss = criterion(outputs, labels)

torch.max()这个函数返回的是两个值:

  1. 第一个值是具体的value(我们用下划线_表示)
  2. 第二个值是value所在的index(也就是preds)。

数字1其实可以写为dim=1,这里简写为1,python也可以自动识别,dim=1表示输出所在行的最大值,若改写成dim=0则输出所在列的最大值。
比如说测试集有10个数据,那么训练好的网络将会预测这10个数据,得到一个10×2的矩阵(假设是二分类问题),比如说预测结果是下面这个矩阵。
_, preds = torch.max(outputs.data, 1)_第1张图片
那么,这个 下划线_ 表示的就是具体的value,也就是输出的最大值。那么为什么用 下划线_,可不可以用其他的变量名称来代替,比如x?答案自然是可以的。

你可能感兴趣的:(深度学习,pytorch)