numpy.where
函数是三元表达式true_value if condition else false_value
的向量化版本。
numpy.where
的基本语法假如有一个布尔值数组和两个数值数组:
In [3]: xarr = np.array([1.1, 1.2, 1.3, 1.4, 1.5])
In [4]: yarr = np.array([2.1, 2.2, 2.3, 2.4, 2.5])
In [5]: cond = np.array([True, False, True, True, False])
假设cond中的元素为True时,选择xarr中对应的元素,否则选择yarr中对应的元素。我们可以使用Python中列表推导式来完成:
In [6]: [[x if cond else y] for (x, cond, y) in zip(xarr, cond, yarr)]
Out[6]: [[1.1], [2.2], [1.3], [1.4], [2.5]]
这种方法会产生两个问题:一是处理速度慢,二是当数组使多维时,这种方法就会失效。NumPy数组中专门用于逻辑判断的是numpy.where
函数:
In [7]: np.where(cond, xarr, yarr)
Out[7]: array([1.1, 2.2, 1.3, 1.4, 2.5])
numpy.where
函数第二三个参数不要求必须是数组,还可以是标量。如:我们随机生成一个多维数组,将其中的正值替换为1,负值替换为-1。In [10]: arr = np.random.randn(4, 4)
In [11]: arr
Out[11]:
array([[-0.93233381, 0.92597708, -1.0829083 , -1.47264049],
[ 0.54360661, 0.13305021, -0.58304843, -0.62027412],
[ 0.03978567, 0.70116782, -0.70640549, 0.42405301],
[-1.2909755 , -2.87462369, 0.0081447 , 0.39662395]])
In [12]: arr > 0 # 通过条件判断生成一个布尔值数组
Out[12]:
array([[False, True, False, False],
[ True, True, False, False],
[ True, True, False, True],
[False, False, True, True]])
In [13]: np.where(arr>0, 1, -1)
Out[13]:
array([[-1, 1, -1, -1],
[ 1, 1, -1, -1],
[ 1, 1, -1, 1],
[-1, -1, 1, 1]])
numpy.where
将标量和数组结合起来,可以替换数组的中的特殊值。如将arr数组中的负数替换成0:In [14]: np.where(arr>0, arr, 0)
Out[14]:
array([[0. , 0.92597708, 0. , 0. ],
[0.54360661, 0.13305021, 0. , 0. ],
[0.03978567, 0.70116782, 0. , 0.42405301],
[0. , 0. , 0.0081447 , 0.39662395]])
总结,numpy.where
中的第二三个参数既可以是同等大小的数组,也可以是标量。