官网链接:TORCH.ARGMAX
torch.argmax(input) → LongTensor
返回输入张量 input 所有元素中的最大值的下标(如果有多个最大值,则返回第一个最大值的索引)。
import torch
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a))
输出结果:
tensor([[-0.7018, 1.1887, -0.2344, 0.3216],
[ 1.3548, -0.8575, -1.0585, -0.3462],
[ 0.5845, 0.2345, 1.6444, 1.1129],
[-1.1226, -0.5765, -0.4906, 0.0132]])
tensor(10)
在所有的元素中,第11个元素 1.6444 最大,其索引是 10 ,因此返回 tensor(10)。
torch.argmax(input, dim=None, keepdim=False)
import torch
a = torch.tensor(
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
])
print(a.shape)
b = torch.argmax(a, dim=0) # 压缩行,返回列最大值的序号
print(b)
print(b.shape)
输出结果:
torch.Size([3, 4])
tensor([1, 2, 0, 1])
torch.Size([4])
指定的维度是 0 ,也就是行,要压缩行,就要找列的最大值。
从 [3, 4] -> [4],可见第一个维度 3 消失了。
import torch
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1)) #压缩列,返回行最大值的序号
输出结果:
tensor([[-1.3736, 0.8958, -0.6470, 1.3395],
[-0.4279, 0.0682, 0.7635, 1.1857],
[ 1.7861, -0.6515, -0.5456, -0.3066],
[ 1.1898, -0.0208, -0.3662, 0.1799]])
tensor([3, 3, 0, 0])
指定的维度是 1 ,也就是列,要压缩列,就要找行的最大值。
import torch
a = torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
print(a.shape)
b = torch.argmax(a, dim=0)
print(b)
print(b.shape)
输出结果:
torch.Size([2, 3, 4])
tensor([[0, 1, 0, 0],
[0, 1, 0, 0],
[1, 0, 1, 0]])
torch.Size([3, 4])
从 [2, 3, 4] -> [3, 4],可见第一个维度 2 消失了。
import torch
a = torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
print(a.shape)
b = torch.argmax(a, dim=0, keepdim=True)
print(b)
print(b.shape)
输出结果:
torch.Size([2, 3, 4])
tensor([[[0, 1, 0, 0],
[0, 1, 0, 0],
[1, 0, 1, 0]]])
torch.Size([1, 3, 4])
与实例2的不同之处:加了 keepdim=True 参数,输出从 [3, 4] -> [1, 3, 4],保留了被压缩的第一维,只不过从 2 变成了压缩后的 1 。