官网例子
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
>>> torch.argmax(a, dim=1)
tensor([ 0, 2, 0, 1])
再给出一些例子:
input=torch.FloatTensor([1,3,1,8,0])
output=torch.argmax(input,dim=0)
print('input shape ',input.shape)
print('output shape ',output.shape)
print(output)
程序输出:
input shape torch.Size([5])
output shape torch.Size([])
tensor(3)
input = torch.FloatTensor([[[1,0],
[0,0]],
[[2,2],
[2,6]],
[[3,7],
[3,3]],
[[9,9],
[0,0]]])
output=torch.argmax(input,dim=0)
print('input shape ',input.shape)
print('output shape ',output.shape)
print(output)
程序输出:
input shape torch.Size([4, 2, 2])
output shape torch.Size([2, 2])
tensor([[3, 3],
[2, 1]])
首先理解输出输出的维度变化,
input shape 4,2,2
argmax(input,dim=0)
output的维度就会少了第零维,变成了 (2,2)
如图所示,四根斜线代表在这4个地方取一个最大值,就是斜线穿过的四个点取一个最大值。
最后就生成了四个值,形状是(2,2)