tenflow 中tf.where()用法
where(condition, x=None, y=None, name=None)
condition, x, y 相同维度,condition是bool型值,True/False
1,where(condition)的用法
condition是bool型值,True/False
返回值,是condition中元素为True对应的索引
看个例子:
-
import tensorflow
as tf
-
a = [[
1,
2,
3],[
4,
5,
6]]
-
b = [[
1,
0,
3],[
1,
5,
1]]
-
condition1 = [[
True,
False,
False],
-
[
False,
True,
True]]
-
condition2 = [[
True,
False,
False],
-
[
False,
True,
False]]
-
with tf.Session()
as sess:
-
print(sess.run(tf.where(condition1)))
-
print(sess.run(tf.where(condition2)))
结果1:
-
[[
0
0]
-
[
1
1]
-
[
1
2]]
结果2:
-
[[
0
0]
-
[
1
1]]
2, where(condition, x=None, y=None, name=None)的用法
condition, x, y 相同维度,condition是bool型值,True/False
返回值是对应元素,condition中元素为True的元素替换为x中的元素,为False的元素替换为y中对应元素
x只负责对应替换True的元素,y只负责对应替换False的元素,x,y各有分工
由于是替换,返回值的维度,和condition,x , y都是相等的。
看个例子:
-
import tensorflow
as tf
-
x = [[
1,
2,
3],[
4,
5,
6]]
-
y = [[
7,
8,
9],[
10,
11,
12]]
-
condition3 = [[
True,
False,
False],
-
[
False,
True,
True]]
-
condition4 = [[
True,
False,
False],
-
[
True,
True,
False]]
-
with tf.Session()
as sess:
-
print(sess.run(tf.where(condition3,x,y)))
-
print(sess.run(tf.where(condition4,x,y)))
结果:
-
1, [[
1
8
9]
-
[
10
5
6]]
-
2, [[
1
8
9]
-
[
4
5
12]]