函数Numpy.where()可以对Numpy数组(ndarray)进行条件的指定,对满足条件的元素进行替换,修改,或一些特定的处理。同时也可以取得满足条件元素的进行索引。
对以下内容进行说明:
Numpy.where(condition[,x,y])
当条件(condition)满足时为真(True),返回x。当条件(condition)不满足时为假(False),返回y。返回的结果也是一个ndarray数组。
import numpy as np
a = np.arange(9).reshape((3, 3))
print(a)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
print(np.where(a < 4, -1, 100))
# [[ -1 -1 -1]
# [ -1 100 100]
# [100 100 100]]
print(np.where(a < 4, True, False))
# [[ True True True]
# [ True False False]
# [False False False]]
不使用numpy.where(),直接使用条件式,可以得到一个满足条件时为真(True),不满足时为假(False)的ndarray数组。
print(a < 4)
# [[ True True True]
# [ True False False]
# [False False False]]
多个条件式时,用()将其分开,条件式与条件式之间可以用 &,| 逻辑运算符进行连接。但不可以使用and,or等关键字链接。
print(np.where((a > 2) & (a < 6), -1, 100))
# [[100 100 100]
# [ -1 -1 -1]
# [100 100 100]]
print(np.where((a > 2) & (a < 6) | (a == 7), -1, 100))
# [[100 100 100]
# [ -1 -1 -1]
# [100 -1 100]]
print((a > 2) & (a < 6))
# [[False False False]
# [ True True True]
# [False False False]]
print((a > 2) & (a < 6) | (a == 7))
# [[False False False]
# [ True True True]
# [False True False]]
多个条件式时,不适用numpy.where(),也能返回得到一个True,False的ndarray数组。
满足条件或不满足条件的元素的替换可以参考上述的例子。也可以对只满足条件或,只不满足条件的元素进行替换。将数组中元素代入numpy.where()的参数x,y即可。
print(np.where(a < 4, -1, a))
# [[-1 -1 -1]
# [-1 4 5]
# [ 6 7 8]]
print(np.where(a < 4, a, 100))
# [[ 0 1 2]
# [ 3 100 100]
# [100 100 100]]
numpy.where()返回一个新的ndarray数组,原数组不变。
a_org = np.arange(9).reshape((3, 3))
print(a_org)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
a_new = np.where(a_org < 4, -1, a_org)
print(a_new)
# [[-1 -1 -1]
# [-1 4 5]
# [ 6 7 8]]
print(a_org)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
更替数组自身值的时候,可以这样写。
a_org[a_org < 4] = -1
print(a_org)
# [[-1 -1 -1]
# [-1 4 5]
# [ 6 7 8]]
numpy.where()不光可以返回原数组的值,也可以进行计算后,返回一个新的数组。
print(np.where(a < 4, a * 10, a))
# [[ 0 10 20]
# [30 4 5]
# [ 6 7 8]]
省略numpy.where()的参数x,y时,返回满足条件元素的角标。
print(np.where(a < 4))
# (array([0, 0, 0, 1]), array([0, 1, 2, 0]))
print(type(np.where(a < 4)))
#
但是,上述的结果并不好理解。可以使用list(),zip()以及*对其进行以下整理,结果如下。
print(list(zip(*np.where(a < 4))))
# [(0, 0), (0, 1), (0, 2), (1, 0)]
这里的(0, 0), (0, 1), (0, 2), (1, 0)为原数组中满足条件元素的角标。
多维度的数组也同样适用。
a_3d = np.arange(24).reshape(2, 3, 4)
print(a_3d)
# [[[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]
#
# [[12 13 14 15]
# [16 17 18 19]
# [20 21 22 23]]]
print(np.where(a_3d < 5))
# (array([0, 0, 0, 0, 0]), array([0, 0, 0, 0, 1]), array([0, 1, 2, 3, 0]))
print(list(zip(*np.where(a_3d < 5))))
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 1, 0)]
一维数组也同样适用。但如果也使用list(),zip()以及*对其进行整理的话,返回的结果也是一个2维的角标集合。
a_1d = np.arange(6)
print(a_1d)
# [0 1 2 3 4 5]
print(np.where(a_1d < 3))
# (array([0, 1, 2]),)
print(list(zip(*np.where(a_1d < 3))))
# [(0,), (1,), (2,)]
可以使用tolist(),将其转换成list格式。
print(np.where(a_1d < 3)[0])
# [0 1 2]
print(np.where(a_1d < 3)[0].tolist())
# [0, 1, 2]