pytorch中的torch.argmax函数

转载自:pytorch中的torch.argmax函数 - 知乎

x = torch.randn(3, 5)
print(x)
print(torch.argmax(x))
print(torch.argmax(x, dim=0))
print(torch.argmax(x, dim=-2))
print(torch.argmax(x, dim=1))
print(torch.argmax(x, dim=-1))



output:
tensor([[-1.0214,  0.7577, -0.0481, -1.0252,  0.9443],
        [ 0.5071, -1.6073, -0.6960, -0.6066,  1.6297],
        [-0.2776, -1.3551,  0.0036, -0.9210, -0.6517]])
tensor(9)
tensor([1, 0, 2, 1, 1])
tensor([1, 0, 2, 1, 1])
tensor([4, 4, 2])
tensor([4, 4, 2])

 结论:dim的取值为[-2, 1]之间,只能取整,有四个数,0和-2对应,得到的是每一列的最大值,1和-1对应,得到的是每一行的最大值。如果参数中不写dim,则得到的是张量中最大的值对应的索引(从0开始)。

注意:

(1)就是dim等于几,就是表明删除那一维。比如x = torch.randn(3, 5),三行五列的矩阵,就是当dim=1时,就是说明结果是3个数字了,dim=0时,就是结果为5个数字

(2)此外不指定dim维度的时候,是直接按照顺序对所有元素进行遍历。返回的索引是基于所有维度铺平时的索引,比如上面:

print(torch.argmax(x))

结果就是9,这个就是遍历下面的矩阵:

tensor([[-1.0214,  0.7577, -0.0481, -1.0252,  0.9443],
        [ 0.5071, -1.6073, -0.6960, -0.6066,  1.6297],
        [-0.2776, -1.3551,  0.0036, -0.9210, -0.6517]])

1.6297这个元素就是最大的,其索引等于平铺后的15个元素中的第8个,即索引为9

你可能感兴趣的:(pytorch,pytorch,人工智能,python)