TensorFlow笔记——(1)理解tf.control_dependencies与control_flow_ops.with_dependencies

引言

我们在实现神经网络的时候经常会看到tf.control_dependencies的使用,但是这个函数究竟是什么作用,我们应该在什么情况下使用呢?今天我们就来一探究竟。

理解

其实从字面上看,control_dependencies 是控制依赖的意思,我们可以大致推测出来,这个函数应该使用来控制就算图节点之间的依赖的。其实正是如此,tf.control_dependencies()是用来控制计算流图的,给图中的某些节点指定计算的顺序。

原型分析

tf.control_dependencies(self, control_inputs)
 arguments:control_inputs: A list of `Operation` or `Tensor` objects 
which must be executed or computed before running the operations 
defined in the context. (注意这里control_inputs是listreturn:  A context manager that specifies control dependencies 
for all operations constructed within the context.

通过以上的解释,我们可以知道,该函数接受的参数control_inputs,是Operation或者Tensor构成的list。返回的是一个上下文管理器,该上下文管理器用来控制在该上下文中的操作的依赖。也就是说,上下文管理器下定义的操作是依赖control_inputs中的操作的,control_dependencies用来控制control_inputs中操作执行后,才执行上下文管理器中定义的操作。

例子1

如果我们想要确保获取更新后的参数,name我们可以这样组织我们的代码。

opt = tf.train.Optimizer().minize(loss)

with tf.control_dependencies([opt]): #先执行opt
  updated_weight = tf.identity(weight)  #再执行该操作

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  sess.run(updated_weight, feed_dict={...}) # 这样每次得到的都是更新后的weight

可以看到以上的例子用到了tf.identity(),至于为什么要使用tf.identity(),我在下一篇博客:TensorFlow笔记——(1)理解tf.control_dependencies与control_flow_ops.with_dependencies中有详细的解释,不懂的可以移步了解。

control_flow_ops.with_dependencies

除了常用tf.control_dependencies()我们还会看到,control_flow_ops.with_dependencies(),其实连个函数都可以实现依赖的控制,只是实现的方式不太一样。

with_dependencies(dependencies, output_tensor, name=None)
Produces the content of `output_tensor` only after `dependencies`.
所有的依赖操作完成后,计算output_tensor并返回
  In some cases, a user may want the output of an operation to be
  consumed externally only after some other dependencies have run
  first. This function ensures returns `output_tensor`, but only after all
  operations in `dependencies` have run. Note that this means that there is
  no guarantee that `output_tensor` will be evaluated after any `dependencies`
  have run.

  See also @{tf.tuple$tuple} and @{tf.group$group}.

  Args:
    dependencies: Iterable of operations to run before this op finishes.
    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
    name: (Optional) A name for this operation.

  Returns:
    Same as `output_tensor`.

  Raises:
    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`. 

例子2

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) #从一个集合中取出变量,返回的是一个列表
......
total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train)
......
# tf.group()将多个tensor或者op合在一起,然后进行run,返回的是一个op
update_op = tf.group(*update_ops)
train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                          name='train_op')

可以看到以上的例子用到了tf.group(),至于为什么要使用tf.identity(),我在下一篇博客:TensorFlow笔记——(2) tf.group(), tf.tuple 和 tf.identity()中有详细的解释,不懂的可以移步了解。

参考文档

1、tensorflow学习笔记(四十一):control dependencies
2、tf.control_dependencies与tf.identity组合详解

你可能感兴趣的:(TensorFlow)