numpy.unravel_index()
函数的作用是获取一个/组int
类型的索引值在一个多维数组中的位置。
官方文档:
举个例子:
我有一个ndarray数组A,A.shape = [3, 3, 3, 2],A.dtype=np.int64。如何找到A中最大元素的索引???
yan@yanubuntu:~$ python
Python 2.7.12 (default, Nov 12 2018, 14:36:49)
[GCC 5.4.0 20160609] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import numpy as np
>>> A=np.random.randint(1,100,size=(3,3,3,2))
>>> A
array([[[[25, 36],
[78, 83],
[97, 11]],
[[32, 37],
[14, 10],
[72, 92]],
[[34, 72],
[90, 61],
[62, 31]]],
[[[59, 41],
[53, 12],
[33, 62]],
[[72, 25],
[ 9, 19],
[64, 93]],
[[76, 42],
[98, 21],
[31, 40]]],
[[[58, 27],
[64, 78],
[52, 34]],
[[63, 45],
[79, 3],
[78, 2]],
[[ 5, 31],
[84, 15],
[64, 38]]]])
我可以使用np.argmax()
函数来完成:
>>> ind_max=np.argmax(A)
>>> ind_max
32
此时得到的最大值索引是将A进行flatten成一维数组之后的索引值,如何得到最大元素在原数组A中的索引呢?这就是np.unravel_index()
函数做的事情:
>>> ind_max_src=np.unravel_index(ind_max, A.shape)
>>> ind_max_src
(1, 2, 1, 0)
>>> A[1,2,1,0]
98
这里np.unravel_index()
函数的第一个参数indices
除了可以是int型标量值,还可以是一个int型数组,当indices
为int型数组时,就是对数组中的每一个元素执行相同的运算过程。具体输出格式参考官方文档中的Examples
。
另外一个函数numpy.ravel_multi_index(),执行和np.unravel_index()
函数相反的运算。