代码:
outputs = model(inputs)
_, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability
outputs是训练中,模型的输出值,是以tensor的形式存在,使用torch.max()函数,将其转换为预测值、预测值索引(看一下预测出的是哪些交互),后续方便与labels对比,从而求出准确率等指标。
values, indices = torch.max(input, dim)
torch.max()[0].data # 只返回最大值的每个数
troch.max()[1].data # 只返回最大值的每个索引
input
是上述模型的输出tensor,
dim
是维度,比如dim=0,得到的tensor形状会去掉第0维的维度。
例子如下:
import torch
a = torch.tensor([[[1,5,62,54], [2,6,2,6], [2,65,2,6]],[[1,0,0,0],[2,0,0,0],[3,1,1,99]]])
a是一个tensor:
tensor([[[ 1, 5, 62, 54],
[ 2, 6, 2, 6],
[ 2, 65, 2, 6]],
[[ 1, 0, 0, 0],
[ 2, 0, 0, 0],
[ 3, 1, 1, 99]]])
当dim分别为0,1,2时的结果如下,其中values是值,indices是对应值的索引。
>>> torch.max(a,0) # 得到整个矩阵对应位置的最大值
torch.return_types.max(
values=tensor([[ 1, 5, 62, 54],
[ 2, 6, 2, 6],
[ 3, 65, 2, 99]]),
indices=tensor([[0, 0, 0, 0],
[0, 0, 0, 0],
[1, 0, 0, 1]]))
>>> torch.max(a,1) # 得到每列的最大值
torch.return_types.max(
values=tensor([[ 2, 65, 62, 54],
[ 3, 1, 1, 99]]),
indices=tensor([[1, 2, 0, 0],
[2, 2, 2, 2]]))
>>> torch.max(a,2) # 得到每行的最大值
torch.return_types.max(
values=tensor([[62, 6, 65],
[ 1, 2, 99]]),
indices=tensor([[2, 1, 1],
[0, 0, 3]]))
输出得到结果的shape,可以发现规律,当dim为0时,得到的结果为去掉第0维之后的维度。
>>> a.shape
torch.Size([2, 3, 4])
>>> max_0 = torch.max(a,0)
>>> max_0.values.shape
torch.Size([3, 4])
>>> max_1 = torch.max(a,1)
>>> max_1.values.shape
torch.Size([2, 4])
>>> max_2 = torch.max(a,2)
>>> max_2.values.shape
torch.Size([2, 3])
torch.argmax()函数
另外还有一个类似的函数,同样可以求最大值,区别在于返回值不再像torch.max()有两个返回值:值,索引,torch.argmax()的返回值只有一个:索引。
参考: