从三元表达式(ternary expression)理解 numpy.where

三元表达式的一般形式:

x if condition else y

例如,我们有如下的两个数组以及条件:

xarr = np.array([1.1, 1.2, 1.3, 1.4, 1.5])

yarr = np.array([2.1, 2.2, 2.3, 2.4, 2.5])

cond = np.array([True, False, True, True, False])

当条件为 True 时,我们想从 xarr 中取值,反之从 yarr 中取值。用三元表达式需要这样来实现:

res = [(x if c else y)
       for x, y, c in zip(xarr, yarr, cond)]
res
"""
[1.1, 2.2, 1.3, 1.4, 2.5]
"""

但这个实现方法有许多问题。首先,对于大的数组执行速度不会太快。其次,不能对多维数组进行操作。


numpy.where 函数可看作是对三元表达式的向量化扩展。

numpy.where(condition, [x, y, ]/)

这里的 condition 为布尔类型的数组,x, y 也都为数组 (也可以为标量)。如果所有这些 arrays 都是一维的,那就等价于我们上面所写的三元表达式:

[(x if c else y) for c, x, y in zip(condition, x, y)]

上面的例子用 numpy.where 只需要这样写:

res = np.where(cond, xarr, yarr)
res
"""
[1.1, 2.2, 1.3, 1.4, 2.5]
"""

对于多维数组,例如,我们有如下数组:
arr = np.random.randn(5, 5)
arr
"""
array([[-0.97771187, -0.98135695, -0.13112475, -0.07527619,  1.20978508],
       [-1.12931145, -0.84098807,  2.04738178,  1.38584849, -0.51919951],
       [-0.92975612,  1.03771019, -0.08548654,  1.13116971, -0.89777143],
       [ 0.82876313, -0.73411161, -2.83065065, -1.14866989,  0.78968089],
       [ 0.03728637, -0.69337259, -1.40003486,  0.52986178,  1.34800647]])
"""

我们想把大于 0 的值替换为 2,而小于 0 的值替换为 -2:

arr > 0
"""
array([[False, False, False, False,  True],
       [False, False,  True,  True, False],
       [False,  True, False,  True, False],
       [ True, False, False, False,  True],
       [ True, False, False,  True,  True]])
"""
np.where(arr > 0, 2, -2)
"""
array([[-2, -2, -2, -2,  2],
       [-2, -2,  2,  2, -2],
       [-2,  2, -2,  2, -2],
       [ 2, -2, -2, -2,  2],
       [ 2, -2, -2,  2,  2]])
"""

只将大于 0 的值替换为 2:

np.where(arr > 0, 2, arr)
"""
array([[-0.97771187, -0.98135695, -0.13112475, -0.07527619,  2.        ],
       [-1.12931145, -0.84098807,  2.        ,  2.        , -0.51919951],
       [-0.92975612,  2.        , -0.08548654,  2.        , -0.89777143],
       [ 2.        , -0.73411161, -2.83065065, -1.14866989,  2.        ],
       [ 2.        , -0.69337259, -1.40003486,  2.        ,  2.        ]])
"""

注意

  • 使用 np.where 创建了一个新的 array。原有的 array 并未改变。
  • np.where(arr > 0, 2, arr),2 其实执行了广播操作
  • np.where(arr > 0, 2, -2),2 和 -2 都执行了广播操作

关于 NumPy 中的广播机制,可以看这篇文章:《NumPy 中的广播》。


References

Python for Data Analysis, 2 n d ^{\rm nd} nd edition. Wes McKinney.

你可能感兴趣的:(Python,编程技巧,numpy,python,数据分析)