等价于:
res = fn1() if pred else fn2()
注意:pred不能是 python bool, pred是个标量Tensor i.e. tf.placeholder(dtype=tf.bool, shape=[])
官网例子
z = tf.mul(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
pred_fn_pairs
:以下两种形式都是正确的
1. [(pred_1, fn_1), (pred_2, fn_2)]
2. {pred_1:fn_1, pred_2:fn_2}
tf.case()
等价于:
if pred_1:
return fn_1()
elif pred_2:
return fn_2()
else:
return default()
import tensorflow as tf
x = tf.constant(0)
y = tf.constant(1)
z = tf.constant(2)
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f2, tf.less(x, z): f1},
default=f3, exclusive=False)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run(r))
如果我们有很多 tensor
或 op
想要一起run
,这时这两个函数就是一个很好的帮手了。
w = tf.Variable(1)
mul = tf.multiply(w, 2)
add = tf.add(w, 2)
group = tf.group(mul, add)
tuple = tf.tuple([mul, add])
# sess.run(group)和sess.run(tuple)都会求Tensor(add)
#Tensor(mul)的值。区别是,tf.group()返回的是`op`
#tf.tuple()返回的是list of tensor。
#这样就会导致,sess.run(tuple)的时候,会返回 Tensor(mul),Tensor(add)的值.
#而 sess.run(group)不会
http://stackoverflow.com/questions/34877523/in-tensorflow-what-is-tf-identity-used-for
tf.while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)
while_loop
可以这么理解
loop_vars = [...]
while cond(*loop_vars):
loop_vars = body(*loop_vars)
示例:
import tensorflow as tf
a = tf.get_variable("a", dtype=tf.int32, shape=[], initializer=tf.ones_initializer())
b = tf.constant(2)
f = tf.constant(6)
# Definition of condition and body
def cond(a, b, f):
return a < 3
def body(a, b, f):
# do some stuff with a, b
a = a + 1
return a, b, f
# Loop, 返回的tensor while 循环后的 a,b,f
a, b, f = tf.while_loop(cond, body, [a, b, f])
with tf.Session() as sess:
tf.global_variables_initializer().run()
res = sess.run([a, b, f])
print(res)