一、简介
def cond(pred, # 谓词,可以理解为判断条件
true_fn=None, # 当谓词为真(True)时返回的函数
false_fn=None, # 当谓词为假(False)时返回的函数
strict=False, #
name=None,
fn1=None,
fn2=None):
API注释:
Return true_fn()
if the predicate pred
is true else false_fn()
.
true_fn
and false_fn
both return lists of output tensors. true_fn
and false_fn
must have the same non-zero number and type of outputs.
Note that the conditional execution applies only to the operations defined in true_fn
and false_fn
. Consider the following simple program:
python
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
If x < y
, the tf.add
operation will be executed and tf.square
operation will not be executed. Since z
is needed for at least one branch of the cond
, the tf.multiply
operation is always executed, unconditionally.
Although this behavior is consistent with the dataflow model of TensorFlow, it has occasionally surprised some users who expected a lazier semantics.
Note that cond
calls true_fn
and false_fn
exactly once (inside the call to cond
, and not at all during Session.run()
). cond
stitches together the graph fragments created during the true_fn
and false_fn
calls with some additional graph nodes to ensure that the right branch gets executed depending on the value of pred
.
tf.cond
supports nested structures as implemented in tensorflow.python.util.nest
. Both true_fn
and false_fn
must return the same (possibly nested) value structure of lists, tuples, and/or named tuples.
Singleton lists and tuples form the only exceptions to this: when returned by true_fn
and/or false_fn
, they are implicitly unpacked to single values. This behavior is disabled by passing strict=True
.
Google翻译:
如果谓词pred
为真,则返回true_fn()
,否则返回false_fn()
。
true_fn
和false_fn
都返回输出张量列表。 true_fn
和false_fn
必须具有相同的非零数字和输出类型。
请注意,条件执行仅适用于true_fn
和false_fn
中定义的操作。考虑以下简单程序:
z = tf.multiply(a,b)
result = tf.cond(x lambda:tf.add(x,z),lambda:tf.square(y))
如果x
tf.add
操作并且不执行tf.square
操作。由于cond
的至少一个分支需要z
,所以总是无条件地执行tf.multiply
操作。
虽然这种行为与TensorFlow的数据流模型一致,但它偶尔会让一些期望更加懒惰语义的用户感到惊讶。
注意cond
只调用一次true_fn
和false_fn
在cond
的调用中,在Session.run()
期间不调用)。 cond
将在true_fn
和false_fn
调用期间创建的图形片段与一些额外的图形节点拼接在一起,以确保根据pred
的值执行正确的分支。
tf.cond
支持在tensorflow.python.util.nest
中实现的嵌套结构。 true_fn
和false_fn
都必须返回列表,元组和/或命名元组的相同(可能是嵌套的)值结构。
单例列表和元组构成了对此的唯一例外:当由true_fn
和/或false_fn
返回时,它们被隐式解压缩为单个值。通过传递strict = True
禁用此行为。
总结:该函数类似与if...else...
分支,当谓词判断为真时,调用前面一个函数,谓词判断为假时则调用后面一个函数。这在写程序时很有用,因为在TensorFlow中,我们需要先建立Graph,此时数据是不可知的,常规方法并不能直接判断,这里就提供了一个借口,可以在数据未知时进行判断。pred: A scalar determining whether to return the result of true_fn
or
false_fn
.
true_fn: The callable to be performed if pred is true.
false_fn: The callable to be performed if pred is false.
strict: A boolean that enables/disables ‘strict’ mode; see above.
name: Optional name prefix for the returned tensors.
二、参数
在实际的使用过程中,我们一般只需要使用以下参数即可。
参数 | ||
---|---|---|
pred |
A scalar determining whether to return the result of true_fn or false_fn . |
一个标量,或者说是一个判断条件,用以判断返回true_fn 或者 false_fn |
true_fn |
The callable to be performed if pred is true. |
当 pred 为真时,返回的函数 |
false_fn |
The callable to be performed if pred is false. |
当 pred 为假时,返回的函数 |
strict |
A boolean that enables/disables ‘strict’ mode; see above. | 一个bool值,表示是否使用’strict’模式,详见上 |
name |
Optional name prefix for the returned tensors. | 名称,可选参数 |
三、代码
import tensorflow as tf
import numpy as np
x = tf.constant(2)
y = tf.constant(1)
def f1(): return tf.multiply(x, 17)
def f2(): return tf.add(y, 23)
r = tf.cond(tf.less(x, y), f1, f2)
with tf.Session() as sess:
print(sess.run(r))
运行结果:因为2<1为False,执行f2,得到结果1+23=24
24
import tensorflow as tf
import numpy as np
x = tf.constant(2)
y = tf.constant(5) # 与前面程序的区别仅仅是y取值不同
def f1(): return tf.multiply(x, 17)
def f2(): return tf.add(y, 23)
r = tf.cond(tf.less(x, y), f1, f2)
with tf.Session() as sess:
print(sess.run(r))
运行结果:因为2<5为True,这里执行f1,返回2*17=34。
34
为了方便,也可以使用lambda来定义函数。
# coding=utf-8
import tensorflow as tf
import numpy as np
a = tf.placeholder(dtype=tf.float32)
# 随便定义一些计算逻辑
b = tf.add(a, 32)
c = tf.add(a, 56)
res = tf.cond(a < 10, lambda: b + 10, lambda: c * 2)
with tf.Session() as sess:
print(sess.run(res, feed_dict={a: 13}))
计算结果:
138.0