np.argmax()函数用法解析——通俗易懂

  • 目录
    • 简介
    • 一维数组用法
    • 多维数组用法
      • 二维
      • 高维

0. 简介

numpy.argmax(array, axis) 用于返回一个numpy数组中最大值的索引值。当一组中同时出现几个最大值时,返回第一个最大值的索引值。

在运算时,相当于剥掉一层中括号,返回一个数组,分为一维和多维。一维数组剥掉一层中括号之后就成了一个索引值,是一个数,而n维数组剥掉一层中括号后,会返回一个 n-1 维数组,而剥掉哪一层中括号,取决于axis的取值。

n维的数组的 axis 可以取值从 0 到 n-1,其对应的括号层数为从最外层向内递进,详见后文。

一、一维数组的用法

one_dim_array = np.array([1, 4, 5, 3, 7, 2, 6])
print(np.argmax(one_dim_array))

运算后,降一维,成为一个数值,7的索引值维4,所以运算结果:

4

二、多维数组的用法

1. 二维

遵循运算之后降一维的原则,因此返回的会是一个一维的array。同时,axis的取值为0和1,对应剥掉的中括号,将里面的内容直接按逗号分隔:
0 —— 外层
1 —— 内层

举例如下:

two_dim_array = np.array([[1, 3, 5], [0, 4, 3]])
max_index_axis0 = np.argmax(two_dim_array, axis = 0)
max_index_axis1 = np.argmax(two_dim_array, axis = 1)
print(max_index_axis0)
print(max_index_axis1)

输出结果

[0 1 0] 
[2 1]

这里的two_dim_array是一个 2×3 的矩阵,对应axis为:
array axis
2 —— 0
3 —— 1

所以,在axis为0时,剥掉2,返回一个1×3的数组;在axis为1时,剥掉3,返回一个1×2的数组

two_dim_array = np.array([[1, 3, 5], [0, 4, 3]])
max_index_axis0 = np.argmax(two_dim_array, axis = 0)
"""
在axis为0时,0层括号置外面提出来,0 层内按 "," 换行对齐,
[[1, 3, 5], 
 [0, 4, 3]]
分别对已经对齐的元素按照 0 层括号外面的逗号分组,这里的
[1, 3, 5], 
[0, 4, 3] 
0 层外面没有逗号,因此是一组
然后按列比较大小即可,返回值为
[argmax(1,0), argmax(3,4), argmax(5,3)]:
[0, 1, 0]
"""

同样的思路可以用在axis为1时

max_index_axis1 = np.argmax(two_dim_array, axis = 1)

"""
在axis为1时,则从外向内 1 层的中括号,提出来,1层内按 "," 换行对齐
[[1, 
  3, 
  5], 
 [0, 
  4, 
  3]]
分别对已经对齐的元素按照 1 层外逗号分组,[1,3,5]一组,[0.4.3]一组。
每组元素进行比较,将 1 层括号变成argmax()
[argmax(1,3,5),
 argmax(0,4,3)]
返回值为
[2, 
 1]
"""

2. 高维

以三维为例,计算思路与二维相同。

三维计算之后降维,将返回一个二维数组。

一个m×n×p维的矩阵,
axis为0,舍去m,返回一个 n×p 维的矩阵
axis为1,舍去n,返回一个 m×p 维的矩阵
axis为2,舍去p,返回一个 m×n 维的矩阵

three_dim_array = [[[1, 2, 3, 4],  [-1, 0, 3, 5]],
				   [[2, 7, -1, 3], [0, 3, 12, 4]],
				   [[5, 1, 0, 19], [4, 2, -2, 13]]]
a = np.argmax(three_dim_array, axis = 0)
print(a)
b = np.argmax(three_dim_array, axis = 1)
print(b)
c = np.argmax(test, axis = 2)
print(c)

例中数组shape为 3×2×4
输出结果为:
0 对应shape 2×4
1 对应shape 3×4
2 对应shape 3×2

[[2 1 0 2]                                                                                                               
 [2 1 1 2]]

[[0 0 0 1]                                                                                                               
 [0 0 1 1]                                                                                                               
 [0 1 0 0]]

[[3 3]                                                                                                                   
 [1 2]                                                                                                                   
 [3 3]]  

由于原理类似,因此以axis = 1 举例解析

b = np.argmax(three_dim_array, axis = 1)
print(b)
"""
保留0层1层中括号,1层内按照逗号进行换行对齐
[[[1, 2, 3, 4],  
  [-1, 0, 3, 5]],
 [[2, 7, -1, 3], 
  [0, 3, 12, 4]],
 [[5, 1, 0, 19], 
  [4, 2, -2, 13]]]

按1层外面逗号分组
[[1, 2, 3, 4],  
 [-1, 0, 3, 5]]为一组
[[2, 7, -1, 3], 
 [0, 3, 12, 4]]为一组
[[5, 1, 0, 19], 
 [4, 2, -2, 13]]为一组
对每组内,按列进行操作,并去掉2层括号
三组分别为:
[argmax(1,-1),argmax(2,0),argmax(3,3),argmax(4,5)]
[argmax(2,0),argmax(7,3),argmax(-1,12),argmax(3,4)]
[argmax(5,4),argmax(1,2),argmax(0,-2),argmax(19,13)]

进而,结果为
[[0 0 0 1]                                                                                                               
 [0 0 1 1]                                                                                                               
 [0 1 0 0]]
"""

当axis为0和2时一样,分组后如下:

axis = 0
three_dim_array = [[[1, 2, 3, 4],  [-1, 0, 3, 5]],
				   [[2, 7, -1, 3], [0, 3, 12, 4]],
				   [[5, 1, 0, 19], [4, 2, -2, 13]]]
axis = 2
three_dim_array = [[[1,
					 2,
		  			 3, 
		  			 4],  
					[-1, 
		  			 0, 
		 			 3, 
		 			 5]],
				   [[2, 
		  			 7, 
		 			 -1, 
		 			 3], 
					[0, 
		 			 3, 
		 			 12, 
		 			 4]],
				   [[5, 
		 			 1, 
		 			 0, 
		  			 19], 
		 			[4, 
		 			 2, 
		 			 -2, 
		 			 13]]]

你可能感兴趣的:(np.argmax()函数用法解析——通俗易懂)