numpy.where()
主要有两种用法,如下:
解释: 满足条件输出x,否则输出y
代码片段1:
import numpy as np
aa = np.arange(10) # [0 1 2 3 4 5 6 7 8 9]
print(np.where(aa < 5, aa, aa * 10))
输出结果为:
[ 0 1 2 3 4 50 60 70 80 90]
代码片段2:
print(np.where([[True, False], [True, True]], [[1, 2], [3, 4]], [[9, 8], [7, 6]]))
输出结果为:
# congdition为一个shape为(2, 2)的类型为布尔的numpy数组,x和y同样对应为(2, 2)的numpy数组
# 逐个分析condition的元素,由于condition[0][0]为True,故对应的值应选择x[0][0],类推...
# 因此最终输出的元素值分别为x[0][0],y[0][1],x[1][0],x[1][1]
[[1 8]
[3 4]]
解释: 返回满足条件的数组下标
代码片段1:
# 1维数组
arr1 = np.array([1, 2, 3, 4, 5]) # shape为(5, )
print(np.where(arr > 2))
print('---------')
arr2 = np.array([[1, 2, 3, 4, 5]]) # shape为(1, 5)
arr2_where = np.where(arr > 2)
print(arr2_where)
print(np.array(arr2_where))
输出结果为:
(array([2, 3, 4], dtype=int64),)
---------
# 操作2维数组所得到结果也为2维数组,结果中,第一行元素为原数
# 组的第一个维度的下标,第二行元素为原数组中第二个维度的下标。
# arr2[0][2]=3,arr2[0][3]=4,arr2[0][4]=5为满足条件的三个元素。
(array([0, 0, 0], dtype=int64), array([2, 3, 4], dtype=int64))
[[0 0 0]
[2 3 4]]
再看一个三维数组的示例,代码片段2:
arr3 = np.arange(8).reshape(2, 2, 2)
print(arr3)
arr3_where = np.where(arr3 > 2)
print(arr3_where)
print(np.array(arr3_where))
输出结果为:
[[[0 1]
[2 3]]
[[4 5]
[6 7]]]
(array([0, 1, 1, 1, 1], dtype=int64), array([1, 0, 0, 1, 1], dtype=int64), array([1, 0, 1, 0, 1], dtype=int64))
# 容易看到arr3[0][1][1]=3满足condition,其它以此类推
[[0 1 1 1 1]
[1 0 0 1 1]
[1 0 1 0 1]]
遇到再补充…