TensorFlow的tf.where函数详解与例子

官方说明:
If both x and y are None, then this operation returns the coordinates of true elements of condition. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input. Indices are output in row-major order.

If both non-None, condition, x and y must be broadcastable to the same shape.

The condition tensor acts as a mask that chooses, based on the value at each element, whether the corresponding element / row in the output should be taken from x (if true) or y (if false).

官方文档很抽象,必须结合例子来理解。一共有两种用法,分别是带有xy参数和不带这两个参数的用法。

用法1

a1=np.array([[1,0,0],[0,1,1]]) 
a2=np.array([[3,2,3],[4,5,6]])
tf.where(tf.equal(a1,1),a1,a2)

输出的结果是


也就是,当condition为真,也就是tf.equal(a1,1,即a1中的元素为1,返回的数组中所在位置元素来自a1,否则来自b1。输出的数组中,原数组a1不等于1的元素被替换成了对应位置b1中的元素。
再来一个例子,

tf.where(tf.equal(a1,1),a1,100+a1)

输出的结果是


数组a1中不等于1的元素,其值加上100。

用法2

不带xy参数的时候,返回满足condition的元素所在位置。需要关注的是返回值的形式。

tf.where(tf.equal(a,1)) 

输出结果


这是一个(3, 2)数组,行数表示满足条件的元素的数目a1中一共有3个元素为1,所有行数为3。每一列代表的是符合条件的元素的坐标,比如第一个元素[0,0],表示第一个满足条件的元素的index是(0,0)。

你可能感兴趣的:(TensorFlow的tf.where函数详解与例子)