网上大多数对max的解释只停留在二维数据,在三维及以上就没有详述,我将对二维数据和三维数据进行详细解释,让你不再有疑虑
torch.max()使用讲解
torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)
在分类问题中,通常使用max()函数对softmax函数的输出值进行操作,求出预测值索引
参数
输出
>>>import torch
>>>a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
>>>print(a)
tensor([[ 1, 5, 62, 54],
[ 2, 6, 2, 6],
[ 2, 65, 2, 6]])
torch.max(a,0)
torch.return_types.max(
values=tensor([ 2, 65, 62, 54]),
indices=tensor([1, 2, 0, 0]))
这个计算过程是:
torch.max(a, 1)
torch.return_types.max(
values=tensor([62, 6, 65]),
indices=tensor([2, 1, 1]))
这个计算过程是:
a = [1,2,13,4,5,6,27,8,9,0,11,12]
a = np.array(a).reshape(3,2,2)
a = torch.Tensor(a)
print(a)
tensor([[[ 1., 2.],
[13., 4.]],
[[ 5., 6.],
[27., 8.]],
[[ 9., 0.],
[11., 12.]]])
torch.max(a,dim=0)
torch.return_types.max(
values=tensor([[ 9., 6.],
[27., 12.]]),
indices=tensor([[2, 1],
[1, 2]]))
计算过程:
torch.max(a,dim=1)
torch.return_types.max(
values=tensor([[13., 4.],
[27., 8.],
[11., 12.]]),
indices=tensor([[1, 1],
[1, 1],
[1, 1]]))
计算过程:
torch.max(a,dim=2)
torch.return_types.max(
values=tensor([[ 2., 13.],
[ 6., 27.],
[ 9., 12.]]),
indices=tensor([[1, 0],
[1, 0],
[0, 1]]))
计算过程: