np.argmax()与torch.max()

  • np.argmax()

解释:接收两个参数,第一个为np数组,第二个为axis,在数组的第axis轴上求最大值,返回数组中最大值的索引值,当一组中同时出现几个最大值时,返回第一个最大值的索引值。看例子:

import numpy as np
a = np.array([
    [
        [1, 5, 5, 2],
        [9, -6, 2, 8],
        [-3, 7, -9, 1]
    ],

    [
        [-1, 7, -5, 2],
        [9, 6, 2, 8],
        [3, 7, 9, 1]
    ],

    [
        [21, 6, -5, 2],
        [9, 36, 2, 8],
        [3, 7, 79, 1]
    ]
])

b = np.argmax(a, axis = 0)
c = np.argmax(a, axis = 1)
d = np.argmax(a, axis = 2)

print(b)
print(c)
print(d)

输出为:

>>b
[[2 1 0 0]
 [0 2 0 0]
 [1 0 2 0]]

>>c
[[1 2 0 1]
 [1 0 2 1]
 [0 1 2 1]]

>>d
[[1 0 1]
 [1 0 2]
 [0 1 2]]

分析:对于一个3*3*4的矩阵,当axis = 0时,在第一个维度上作比较,即三个矩阵作比较,返回的是一个3*4的矩阵,同理,axis = 1时在第二个维度上作比较,返回的是一个3*4的矩阵,axis = 2时在第三个维度上作比较,返回的是一个3*3的矩阵。可以发现,输出相较于输入总会减少一维,具体减少哪一维由axis决定,可以用这个来验证。

  • torch.max()

解释:传入两个参数,一个torch.tensor,一个dim,用法与np.max相似,不过这个返回两个tensor,第一个是沿着dim维的最大值,另一个是对应的索引。同时出现几个最大值时,返回最后一个最大值的索引值。

你可能感兴趣的:(np.argmax()与torch.max())