tf.control_dependencies()

声明:

  1. 翻译tensorflow官方文档并进行了总结
  2. 参考博客tensorflow学习笔记(四十一):control dependencies

tf.control_dependecies()

tf.control_dependencies是tensorflow中的一个flow顺序控制机制,作用有二:插入依赖(dependencies)和清空依赖(依赖是op或tensor)。常见的tf.control_dependenciestf.Graph.control_dependencies的装饰器,它们用法是一样的。通过本文,你将了解:

  • 了解control_dependencies()的顺序控制机制
  • tf.control_dependencies()在batch normalization中的使用示例
  • control_dependencies()两种不正确的使用方式

control_dependencies介绍

tf.control_dependencies()有一个参数control_inputs(这是一个列表,列表中可以是OperationTensor对象),返回一个上下文管理器(通常与with一起使用)。

例1

with tf.control_dependencies([a, b, c]):
	d = ...
	e = ...

session在运行d、e之前会先运行a、b、c。在with tf.control_dependencies之内的代码块受到顺序控制机制的影响。

例2

with tf.control_dependencies([a, b]):
	with tf.control_dependencies([c, d]):
		e = ...

session在运行e之前会先运行a、b、c、d。因为依赖会随着with tf.control_dependencies的嵌套一直继承下去。

例3

with tf.control_dependencies([a, b]):
	with tf.control_dependencies(None):  # 第二层上下文管理器
		with tf.control_dependencies([c, d]):
			e = ...

session在运行e之前会先运行c、d,不需要运行a、b。因为在第二层的上下文管理器中,参数control_inputs的值为None,如此将会清除之前所有的依赖。

在BN中的使用

tf.layers.batch_normalization()中用到的变量——当前估计的均值方差是untrainable的,它们通过每个batch的均值和标准差的移动平均更新值,位于collection——tf.GraphKeys.UPDATE_OPS中。因此,需要在每一轮迭代前插入这个操作:

with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
	train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)

这样,在每次迭代调用train_step之前都会先当前batch下的均值方差的移动平均值。

几种不正确的使用方式

例1

op或tensor在流图中的顺序由它们的创建位置决定。

# 不正确
def my_func(pred, tensor):
	t = tf.matmul(tensor, tensor)
	with tf.control_dependencies([pred]):
		# matmul op的定义在context之外,context内只有一个op或tensor不能继承依赖。
		return t

# 应改为
def my_func(pred, tensor):
	with tf.control_dependencies([pred]):
		# 应将t的创建放到context之内
		t = tf.matmul(tensor, tensor)
		return t

例2

tensorflow在求导过程中自动忽略常数项。

# 不正确
loss = model.loss()
with tf.control_dependencies(dependencies):
	loss = loss + tf.constant(1)
return tf.gradients(loss, model.variables)

因为常数项tf.constant(1)在back propagation时被忽略了,所以依赖性在BP的时候也不会被执行。

例3

w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)

ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
	ema_val = ema.average(update)

with tf.Session() as sess:
	tf.global_variables_initializer().run()
	for i in range(3):
		print(sess.run([ema_val]))

# 应改为
with tf.control_dependencies([ema_op]):
	ema_val = tf.identity(ema.average(update))  # 加一个identity

看起来好像是在运行ema_val之前先执行ema_op,实际不然。因为ema.average(update)不是一个op,它只是从ema对象的一个字典中取出键对应的tensor`而已。这个清空跟上文例一很像。

例4

import tensorflow as tf
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)

ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
    w1 = tf.Variable(2.0)
    ema_val = ema.average(update)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(3):
        print(sess.run([ema_val, w1]))

这种情况下,control_dependencies也不工作,原因如下:

#这段代码出现在Variable类定义文件中第287行,
# 在创建Varible时,tensorflow是移除了dependencies了的
#所以会出现 control 不住的情况
with ops.control_dependencies(None):
    ...    

你可能感兴趣的:(TensorFlow)