转载自:pytorch中的torch.argmax函数 - 知乎
x = torch.randn(3, 5)
print(x)
print(torch.argmax(x))
print(torch.argmax(x, dim=0))
print(torch.argmax(x, dim=-2))
print(torch.argmax(x, dim=1))
print(torch.argmax(x, dim=-1))
output:
tensor([[-1.0214, 0.7577, -0.0481, -1.0252, 0.9443],
[ 0.5071, -1.6073, -0.6960, -0.6066, 1.6297],
[-0.2776, -1.3551, 0.0036, -0.9210, -0.6517]])
tensor(9)
tensor([1, 0, 2, 1, 1])
tensor([1, 0, 2, 1, 1])
tensor([4, 4, 2])
tensor([4, 4, 2])
结论:dim的取值为[-2, 1]之间,只能取整,有四个数,0和-2对应,得到的是每一列的最大值,1和-1对应,得到的是每一行的最大值。如果参数中不写dim,则得到的是张量中最大的值对应的索引(从0开始)。
注意:
(1)就是dim等于几,就是表明删除那一维。比如x = torch.randn(3, 5),三行五列的矩阵,就是当dim=1时,就是说明结果是3个数字了,dim=0时,就是结果为5个数字
(2)此外不指定dim维度的时候,是直接按照顺序对所有元素进行遍历。返回的索引是基于所有维度铺平时的索引,比如上面:
print(torch.argmax(x))
结果就是9,这个就是遍历下面的矩阵:
tensor([[-1.0214, 0.7577, -0.0481, -1.0252, 0.9443],
[ 0.5071, -1.6073, -0.6960, -0.6066, 1.6297],
[-0.2776, -1.3551, 0.0036, -0.9210, -0.6517]])
1.6297这个元素就是最大的,其索引等于平铺后的15个元素中的第8个,即索引为9