np.argmax输出的是什么?

重点说一下 argmax输出的是什么东西

test = np.array([[1, 2, 3],
                 [2, 3, 4],
                 [5, 4, 3],
                 [8, 7, 2]])  # 构建一个4X3的矩阵
out = np.argmax(test, axis=1)  # axis=1:按行查找最大元素 axis=0:按列查找最大元素
print(out)

输出是:

[2 2 0 0]  # 按行查找出的最大元素的索引号

输出值第一个元素为什么是2 ---> 按行索引 

                                            --->第一行[1, 2, 3] 最大值是3

                                            --->3的索引值是2(数组,从0开始)

                                            --->第二行[2, 3, 4] 最大值是4

                                            --->4的索引值是2

                                            --->第三行[5, 4, 3] 最大值是5

                                            --->5的索引值是0

                                            --->第四行[8, 7, 2] 最大值是8

                                            --->8的索引值是0

----->所以输出值是[2 2 0 0]

按列索引同理

那么用处在哪呢?就我用到的地方就是使用MNIST数据集的时候,因为这个数据集用一个矩阵表示了一张图片表示数字几。

[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]

也就像这样一个矩阵。第一行在索引5处有个1,说明这行数据表示对应的图片是手写的数字5....以此类推

使用argmax按行取出最大值的索引值就可以得到一个数字的矩阵了。

actuals = np.argmax(test, axis=1)
[5 5 2 3 7 3]  # 使用argmax对上面的数据集的输出

                                    



你可能感兴趣的:(np.argmax输出的是什么?)