04_Numpy的函数np.where()—满足条件的处理

04_Numpy的函数np.where()—满足条件的处理

函数Numpy.where()可以对Numpy数组(ndarray)进行条件的指定,对满足条件的元素进行替换,修改,或一些特定的处理。同时也可以取得满足条件元素的进行索引。

对以下内容进行说明:

  • numpy.where()的概要
  • 多个条件式的使用
  • 满足条件元素的替换
  • 满足条件元素的处理
  • 满足条件元素的索引

Numpy.where()的概要


Numpy.where(condition[,x,y])
当条件(condition)满足时为真(True),返回x。当条件(condition)不满足时为假(False),返回y。返回的结果也是一个ndarray数组。

import numpy as np

a = np.arange(9).reshape((3, 3))
print(a)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(np.where(a < 4, -1, 100))
# [[ -1  -1  -1]
#  [ -1 100 100]
#  [100 100 100]]

print(np.where(a < 4, True, False))
# [[ True  True  True]
#  [ True False False]
#  [False False False]]

不使用numpy.where(),直接使用条件式,可以得到一个满足条件时为真(True),不满足时为假(False)的ndarray数组。

print(a < 4)
# [[ True  True  True]
#  [ True False False]
#  [False False False]]

多个条件式的使用


多个条件式时,用()将其分开,条件式与条件式之间可以用 &,| 逻辑运算符进行连接。但不可以使用and,or等关键字链接。

print(np.where((a > 2) & (a < 6), -1, 100))
# [[100 100 100]
#  [ -1  -1  -1]
#  [100 100 100]]

print(np.where((a > 2) & (a < 6) | (a == 7), -1, 100))
# [[100 100 100]
#  [ -1  -1  -1]
#  [100  -1 100]]

print((a > 2) & (a < 6))
# [[False False False]
#  [ True  True  True]
#  [False False False]]

print((a > 2) & (a < 6) | (a == 7))
# [[False False False]
#  [ True  True  True]
#  [False  True False]]

多个条件式时,不适用numpy.where(),也能返回得到一个True,False的ndarray数组。

满足条件元素的替换


满足条件或不满足条件的元素的替换可以参考上述的例子。也可以对只满足条件或,只不满足条件的元素进行替换。将数组中元素代入numpy.where()的参数x,y即可。

print(np.where(a < 4, -1, a))
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

print(np.where(a < 4, a, 100))
# [[  0   1   2]
#  [  3 100 100]
#  [100 100 100]]

numpy.where()返回一个新的ndarray数组,原数组不变。

a_org = np.arange(9).reshape((3, 3))
print(a_org)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

a_new = np.where(a_org < 4, -1, a_org)
print(a_new)
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

print(a_org)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

更替数组自身值的时候,可以这样写。

a_org[a_org < 4] = -1
print(a_org)
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

满足条件元素的处理


numpy.where()不光可以返回原数组的值,也可以进行计算后,返回一个新的数组。

print(np.where(a < 4, a * 10, a))
# [[ 0 10 20]
#  [30  4  5]
#  [ 6  7  8]]

满足条件元素的索引


省略numpy.where()的参数x,y时,返回满足条件元素的角标。

print(np.where(a < 4))
# (array([0, 0, 0, 1]), array([0, 1, 2, 0]))

print(type(np.where(a < 4)))
# 

但是,上述的结果并不好理解。可以使用list(),zip()以及*对其进行以下整理,结果如下。

print(list(zip(*np.where(a < 4))))
# [(0, 0), (0, 1), (0, 2), (1, 0)]

这里的(0, 0), (0, 1), (0, 2), (1, 0)为原数组中满足条件元素的角标。
多维度的数组也同样适用。

a_3d = np.arange(24).reshape(2, 3, 4)
print(a_3d)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
# 
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

print(np.where(a_3d < 5))
# (array([0, 0, 0, 0, 0]), array([0, 0, 0, 0, 1]), array([0, 1, 2, 3, 0]))

print(list(zip(*np.where(a_3d < 5))))
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 1, 0)]

一维数组也同样适用。但如果也使用list(),zip()以及*对其进行整理的话,返回的结果也是一个2维的角标集合。

a_1d = np.arange(6)
print(a_1d)
# [0 1 2 3 4 5]

print(np.where(a_1d < 3))
# (array([0, 1, 2]),)

print(list(zip(*np.where(a_1d < 3))))
# [(0,), (1,), (2,)]

可以使用tolist(),将其转换成list格式。

print(np.where(a_1d < 3)[0])
# [0 1 2]

print(np.where(a_1d < 3)[0].tolist())
# [0, 1, 2]

你可能感兴趣的:(Numpy)