torch.max(input, dim, keepdim=False) → output tensors (max, max_indices)
输入参数:
input =
输入tensor
dim =
求最大值的维度
keepdim =
是否保持原维度大小输出
输出:
max =
指定维度求得的最大值
max_indices =
指定维度求得的最大值索引
下面以一个大小为(3, 2, 5)的张量为例:
当dim = 0
时
import torch
x = torch.rand(3, 2, 5) # 生成随机数
print(x)
>>> tensor([
[[0.2514, 0.7950, 0.9641, 0.0135, 0.2785],
[0.2575, 0.4410, 0.6829, 0.6668, 0.5850]],
[[0.4725, 0.2015, 0.3406, 0.6989, 0.3551],
[0.9674, 0.5781, 0.6250, 0.3404, 0.4238]],
[[0.2377, 0.3673, 0.3647, 0.1027, 0.9024],
[0.0047, 0.0106, 0.4600, 0.6851, 0.7389]]])
x_value_index = torch.max(x, dim=0, keepdim=True) # 最大值和对应索引
print(x_value_index)
>>> torch.return_types.max(
values=tensor([[
[0.4725, 0.7950, 0.9641, 0.6989, 0.9024],
[0.9674, 0.5781, 0.6829, 0.6851, 0.7389]
]]),
indices=tensor([[
[1, 0, 0, 1, 2],
[1, 1, 0, 2, 2]
]])
)
x_value = torch.max(x, 2, keepdim=True)[0] # 单独取出最大值
print(x_value)
>>> tensor([[[0.9641],[0.6829]],[[0.6989],[0.9674]],[[0.9024],[0.7389]]])
x_index = torch.max(x, 2, keepdim=True)[1] #单独取出最大值索引
print(x_index)
>>> tensor([[[2],[2]], [[3],[0]],[[4],[4]]])
当dim = 1
时
import torch
x = torch.rand(3, 2, 5)
print(x)
>>> tensor([
[[0.5524, 0.1146, 0.4460, 0.4948, 0.7163],
[0.5388, 0.2290, 0.4652, 0.3818, 0.4202]],
[[0.4045, 0.5833, 0.7844, 0.5605, 0.6278],
[0.0335, 0.1204, 0.3604, 0.4386, 0.0286]],
[[0.9510, 0.7801, 0.2879, 0.0369, 0.8103],
[0.9522, 0.7442, 0.5938, 0.1807, 0.2721]]])
x_value_index = torch.max(x, dim=1, keepdim=True)
print(x_value_index)
>>> torch.return_types.max(
values=tensor([
[[0.5524, 0.2290, 0.4652, 0.4948, 0.7163]],
[[0.4045, 0.5833, 0.7844, 0.5605, 0.6278]],
[[0.9522, 0.7801, 0.5938, 0.1807, 0.8103]]
]),
indices=tensor([
[[0, 1, 1, 0, 0]],
[[0, 0, 0, 0, 0]],
[[1, 0, 1, 1, 0]]
])
)
x_value = torch.max(x, 2, keepdim=True)[0]
print(x_value)
>>> tensor([[[0.7163],[0.5388]],[[0.7844],[0.4386]],[[0.9510],[0.9522]]])
x_index = torch.max(x, 2, keepdim=True)[1]
print(x_index)
>>> tensor([[[4],[0]],[[2],[3]],[[0],[0]]])
当dim = 2
时
import torch
x = torch.rand(3, 2, 5) # 生成随机数
print(x)
>>>tensor([
[[0.9249, 0.5676, 0.1035, 0.3701, 0.4501],
[0.5440, 0.9992, 0.7398, 0.1513, 0.3889]],
[[0.2020, 0.4533, 0.1103, 0.9006, 0.8098],
[0.3390, 0.3230, 0.8531, 0.1718, 0.4343]],
[[0.9874, 0.2138, 0.0301, 0.9558, 0.8844],
[0.7317, 0.3344, 0.4552, 0.3196, 0.6343]]
])
x_value_index = torch.max(x, dim = 2, keepdim=True) # 取每一行的最大值
print(x_value_index)
>>> torch.return_types.max(
values=tensor([[[0.9249],[0.9992]],[[0.9006],[0.8531]],[[0.9874],[0.7317]]]),
indices=tensor([[[0],[1]],[[3],[2]],[[0],[0]]]))
x_value = torch.max(x, 2, keepdim=True)[0]
print(x_value)
>>>tensor([[[0.9249],[0.9992]],[[0.9006],[0.8531]],[[0.9874],[0.7317]]])
x_index = torch.max(x, 2, keepdim=True)[1]
print(x_index)
>>> tensor([[[0],[1]],[[3],[2]],[[0],[0]]])