tf.where

  1. 介绍参数形式
  2. 介绍最简单的用法,解释返回值的含义
  3. 根据官方API文档,补充更多的调用形式

参数

tf.where(
    condition,
    x=None,
    y=None,
    name=None
)

返回值类型

x = tf.constant([29.05088806,  27.61298943,  31.19073486,  29.35532951])
indices = tf.where(x > 30)

结果
[[0]
 [2]
 [3]]

这里返回值类型是shape=[m, n]

  • m代表满足条件的个数
  • n代表输入的bool Tensor的维度,对应满足条件的元素在原Tensor的坐标
    • 在本例中,由于输入Tensor是一维的,所以输出较为直观,可以理解为输出的每一个元素对应着一维list的index
sess = tf.InteractiveSession()
x = tf.constant([29.05088806,  27.61298943,  31.19073486,  29.35532951])
x_reshaped = tf.reshape(x,[2,-1])

indices = tf.where(x_reshaped > 28)
result = sess.run(indices)

print(result)
print(type(result))
print(result.shape)

结果
[[0 0]
 [1 0]
 [1 1]]

(3, 2)

[a, b, …, ] 分别代表第0为的index为a,第1维的index为b,以此类推,维度n是由输入的Tensor的维度决定的

更多的调用形式

之前的俩个例子对应着文档的一类情况,即x=None,y=None,下面将讨论x,y不为None的俩种情况

Note:

  • x,y的维度一定要相同

用例一

condition的维度跟x,y的维度相同,实现element-wise选择,如果condition的元素为true,选择x中相应位置的元素作为输出,否则选择y中相应位置的元素作为输出

sess = tf.InteractiveSession()
x = tf.constant([29.05088806,  27.61298943,  31.19073486,  29.35532951,
  30.97266006,  26.67541885,  38.08450317,  20.74983215,
  34.94445419,  34.45999146,  29.06485367,  36.01657104,
  27.88236427,  20.56035233,  30.20379066,  29.51215172,
  33.71149445,  28.59134293,  36.05556488,  28.66994858])

x = tf.reshape(x,[5,-1]) 

a = tf.zeros_like(x)
b = tf.ones_like(x)

index = tf.where(tf.greater(x,30),a,b)
print(sess.run(index))

输出
[[1. 1. 0. 1.]
 [0. 1. 0. 1.]
 [0. 0. 1. 0.]
 [1. 1. 0. 1.]
 [0. 1. 0. 1.]]

用例二

condition的维度是一维,大小对应x的第0维,功能也是根据condition的True/False,分别从x,y中选出相应的元素,只不过粒度不同

# YOUR CODE
x = tf.constant([29.05088806,  27.61298943,  31.19073486,  29.35532951,
33.71149445,  28.59134293,  36.05556488,  28.66994858])

a = tf.zeros([8,2,2])
b = tf.ones([8,2,2])

index = tf.where(tf.greater(x,30),a,b)
print(sess.run(tf.greater(x,30)))
print(sess.run(index))

输出
False False  True False  True False  True False]
[[[1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]]

 [[0. 0.]
  [0. 0.]]

 [[1. 1.]
  [1. 1.]]

 [[0. 0.]
  [0. 0.]]

 [[1. 1.]
  [1. 1.]]

 [[0. 0.]
  [0. 0.]]

 [[1. 1.]
  [1. 1.]]]

你可能感兴趣的:(TensorFlow,TensorFlow,where)