pytorch.max()的详细解释

网上大多数对max的解释只停留在二维数据,在三维及以上就没有详述,我将对二维数据和三维数据进行详细解释,让你不再有疑虑

参考文章

torch.max()使用讲解

torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)

在分类问题中,通常使用max()函数对softmax函数的输出值进行操作,求出预测值索引

参数

  • input:softmax函数输出的一个tensor
  • dim:是max函数索引的维度 0 0 0 1 1 1 0 0 0每列的最大值 1 1 1每行的最大值

输出

  • 函数会返回两个tensor,第一个tensor是每行的最大值,softmax的输出中最大的是1,索引第一个tensor是全1的tensor;第二个tensor是每行最大值的索引

二维数据详细讲述

>>>import torch
>>>a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
>>>print(a)

tensor([[ 1,  5, 62, 54],
        [ 2,  6,  2,  6],
        [ 2, 65,  2,  6]])

dim = 0

torch.max(a,0)
torch.return_types.max(
values=tensor([ 2, 65, 62, 54]),
indices=tensor([1, 2, 0, 0]))

这个计算过程是:

  1. a[dim][0],dim会从0遍历到2,也就是[1,2,2],得到第一个最大值2,index为1
  2. a[dim][1],对[5,6,65],最大值为65,index为2
  3. 最终得到上图结果

dim = 1

torch.max(a, 1)
torch.return_types.max(
values=tensor([62,  6, 65]),
indices=tensor([2, 1, 1]))

这个计算过程是:

  1. a[0][dim],dim会从0遍历到3,也就是[1,5,62,54],得到第一个最大值62,index为2
  2. a[1][dim],对[2,6,2,6],最大值为6,index为1
  3. 最终得到上图结果

三维数据详述

a = [1,2,13,4,5,6,27,8,9,0,11,12]
a = np.array(a).reshape(3,2,2)
a = torch.Tensor(a)
print(a)
tensor([[[ 1.,  2.],
         [13.,  4.]],

        [[ 5.,  6.],
         [27.,  8.]],

        [[ 9.,  0.],
         [11., 12.]]])

dim = 0

torch.max(a,dim=0)
torch.return_types.max(
values=tensor([[ 9.,  6.],
        [27., 12.]]),
indices=tensor([[2, 1],
        [1, 2]]))

计算过程:

  1. a[dim][0][0],dim会从0遍历到2,其他维数值不变,也就是[1,5,9],得到第一个最大值9,index为2
  2. a[dim][0][1],dim会从0遍历到2,其他维数值不变,对[2,6,0]遍历,最大值为6,index为1
  3. a[dim][1][0],dim会从0遍历到2,其他维数值不变,对[13,27,11]遍历,最大值为27,index为1
  4. a[dim][1][1],dim会从0遍历到2,其他维数值不变,对[13,27,11]遍历,最大值为27,index为1
  5. 最终得到上面结果

dim = 1

torch.max(a,dim=1)
torch.return_types.max(
values=tensor([[13.,  4.],
        [27.,  8.],
        [11., 12.]]),
indices=tensor([[1, 1],
        [1, 1],
        [1, 1]]))

计算过程:

  1. a[0][dim][0],dim会从0遍历到1,其他维数值不变,也就是[1,13],得到第一个最大值13,index为1
  2. a[0][dim][1],dim会从0遍历到1,其他维数值不变,对[2,4]遍历,最大值为4,index为1
  3. a[1][dim][0],dim会从0遍历到1,其他维数值不变,对[5,27]遍历,最大值为27,index为1
  4. a[1][dim][1],dim会从0遍历到1,其他维数值不变,对[6,8]遍历,最大值为8,index为1
  5. 最终得到上面结果

dim = 2

torch.max(a,dim=2)
torch.return_types.max(
values=tensor([[ 2., 13.],
        [ 6., 27.],
        [ 9., 12.]]),
indices=tensor([[1, 0],
        [1, 0],
        [0, 1]]))

计算过程:

  1. a[0][0][dim],dim会从0遍历到1,其他维数值不变,也就是[1,2],得到第一个最大值2,index为1
  2. a[0][1][dim],dim会从0遍历到1,其他维数值不变,对[13,4]遍历,最大值为13,index为0
  3. a[1][0][dim],dim会从0遍历到1,其他维数值不变,对[5,6]遍历,最大值为6,index为1
  4. a[1][1][dim],dim会从0遍历到1,其他维数值不变,对[27,8]遍历,最大值为27,index为1
  5. 最终得到上面结果

你可能感兴趣的:(pytorch笔记,pytorch,max,python)