声明:
tf.control_dependencies
是tensorflow中的一个flow顺序控制机制,作用有二:插入依赖(dependencies)和清空依赖(依赖是op或tensor)。常见的tf.control_dependencies
是tf.Graph.control_dependencies
的装饰器,它们用法是一样的。通过本文,你将了解:
control_dependencies()
的顺序控制机制tf.control_dependencies()
在batch normalization中的使用示例control_dependencies()
两种不正确的使用方式tf.control_dependencies()
有一个参数control_inputs
(这是一个列表,列表中可以是Operation
或Tensor
对象),返回一个上下文管理器(通常与with
一起使用)。
with tf.control_dependencies([a, b, c]):
d = ...
e = ...
session在运行d、e之前会先运行a、b、c。在with tf.control_dependencies
之内的代码块受到顺序控制机制的影响。
with tf.control_dependencies([a, b]):
with tf.control_dependencies([c, d]):
e = ...
session在运行e之前会先运行a、b、c、d。因为依赖会随着with tf.control_dependencies
的嵌套一直继承下去。
with tf.control_dependencies([a, b]):
with tf.control_dependencies(None): # 第二层上下文管理器
with tf.control_dependencies([c, d]):
e = ...
session在运行e之前会先运行c、d,不需要运行a、b。因为在第二层的上下文管理器中,参数control_inputs
的值为None
,如此将会清除之前所有的依赖。
tf.layers.batch_normalization()
中用到的变量——当前估计的均值和方差是untrainable的,它们通过每个batch的均值和标准差的移动平均更新值,位于collection——tf.GraphKeys.UPDATE_OPS
中。因此,需要在每一轮迭代前插入这个操作:
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)
这样,在每次迭代调用train_step
之前都会先当前batch下的均值和方差的移动平均值。
op或tensor在流图中的顺序由它们的创建位置决定。
# 不正确
def my_func(pred, tensor):
t = tf.matmul(tensor, tensor)
with tf.control_dependencies([pred]):
# matmul op的定义在context之外,context内只有一个op或tensor不能继承依赖。
return t
# 应改为
def my_func(pred, tensor):
with tf.control_dependencies([pred]):
# 应将t的创建放到context之内
t = tf.matmul(tensor, tensor)
return t
tensorflow在求导过程中自动忽略常数项。
# 不正确
loss = model.loss()
with tf.control_dependencies(dependencies):
loss = loss + tf.constant(1)
return tf.gradients(loss, model.variables)
因为常数项tf.constant(1)
在back propagation时被忽略了,所以依赖性在BP的时候也不会被执行。
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)
ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
ema_val = ema.average(update)
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(3):
print(sess.run([ema_val]))
# 应改为
with tf.control_dependencies([ema_op]):
ema_val = tf.identity(ema.average(update)) # 加一个identity
看起来好像是在运行ema_val
之前先执行ema_op
,实际不然。因为ema.average(update)不是一个op,它只是从
ema对象的一个字典中取出键对应的
tensor`而已。这个清空跟上文例一很像。
import tensorflow as tf
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)
ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
w1 = tf.Variable(2.0)
ema_val = ema.average(update)
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(3):
print(sess.run([ema_val, w1]))
这种情况下,control_dependencies
也不工作,原因如下:
#这段代码出现在Variable类定义文件中第287行,
# 在创建Varible时,tensorflow是移除了dependencies了的
#所以会出现 control 不住的情况
with ops.control_dependencies(None):
...