Using global_step in TensorFlow code

In one project's code I noticed that the author does not use an epoch count to decide when training ends, but instead a variable named n_updates. The code is shown below.

def train(model: Model):
    n_updates = 10000000    # total number of gradient updates (batches), not epochs
    val_interval = 1000     # run validation every 1000 updates

    start = time.time()
    best = float("inf")
    for i in range(n_updates):
        loss = model.train_batch()

        if i % val_interval == 0:
            took = time.time() - start
            print("%05d/%05d %f updates/s %f loss" % (i, n_updates, val_interval / took, loss))
            val_loss = model.val_batch()
            if val_loss < best:
                best = val_loss
                model.save('./best')   # keep the checkpoint with the best validation loss
            start = time.time()

Looking through the project code, it is not hard to see that this is the variable that records the number of gradient updates.


class BaBiRecurrentRelationalNet(Model):
    ......
    def __init__(self, is_testing):
        super().__init__()
        with tf.Graph().as_default(), tf.device('/cpu:0'):
            ......
            self.train_step = self.optimizer.apply_gradients(avg_gradients, global_step=self.global_step)
            ......
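
The excerpt does not show where self.global_step is created; in TensorFlow 1.x such a step counter is conventionally a non-trainable integer variable, as in this minimal sketch (illustrative only, not the project's code):

import tensorflow as tf

# The global step is a non-trainable integer variable kept outside the model parameters.
global_step = tf.train.get_or_create_global_step()

# Roughly equivalent explicit form:
# global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')

# It is then handed to the optimizer, which attaches an increment op to the training op:
# train_step = optimizer.apply_gradients(grads_and_vars, global_step=global_step)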

The project code also uses it as the time step for TensorBoard logging.

    def train_batch(self):
        _, _loss, _logits, _answers, _indices, _summaries, _step, _train_qsize = self.session.run(
            [self.train_step, self.loss, self.out, self.answers, self.task_indices,
             self.summaries, self.global_step, self.train_qsize_op],
            {self.is_training_ph: True})
        if _step % 1000 == 0:
            self._eval(self.train_writer, _answers, _indices, _logits, _summaries, _step)

        return _loss

    def _eval(self, writer, task_answers, task_indices, logits, summaries, train_step):
        writer.add_summary(summaries, train_step)
        ......
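
The step passed as the second argument to add_summary is what TensorBoard uses as the x-axis value for each logged scalar. A generic, self-contained sketch of the same pattern (log directory and names are made up, unrelated to the project):

import tensorflow as tf

# Hedged sketch: using global_step as the TensorBoard x-axis.
global_step = tf.train.get_or_create_global_step()
loss = tf.Variable(1.0)
increment = tf.assign_add(global_step, 1)       # stand-in for a real training op
tf.summary.scalar('loss', loss)
summaries = tf.summary.merge_all()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('/tmp/global_step_demo', sess.graph)
    for _ in range(10):
        _summaries, _step = sess.run([summaries, increment])
        writer.add_summary(_summaries, global_step=_step)   # x-axis value = step count
    writer.close()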

So how does the global_step variable get incremented? My guess was that the gradient-update code takes care of this for us. To verify, let's look at the source of the optimizer's apply_gradients method:

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables have been updated.
      name: Optional name for the returned operation.  Default to the name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step` was not None, that operation also increments `global_step`.
      ......
    """
    ......
    with ops.name_scope(name, self._name) as name:
        ......
        if global_step is None:
            apply_updates = self._finish(update_ops, name)
        else:
            with ops.control_dependencies([self._finish(update_ops, "update")]):
                with ops.colocate_with(global_step):
                    if isinstance(global_step, resource_variable_ops.ResourceVariable):
                        # TODO(apassos): the implicit read in assign_add is slow; consider
                        # making it less so.
                        apply_updates = resource_variable_ops.assign_add_variable_op(
                            global_step.handle,
                            ops.convert_to_tensor(1, dtype=global_step.dtype), name=name)
                    else:
                        apply_updates = state_ops.assign_add(global_step, 1, name=name)
        ......
        return apply_updates

Three things here are worth noting.

  1. The docstring's description of the global_step argument:

global_step: Optional Variable to increment by one after the variables have been updated.

global_step is an optional Variable; after all of the model's variables have had their gradient updates applied, global_step is incremented by one.

  2. The control dependency with ops.control_dependencies([self._finish(update_ops, "update")]). It guarantees that global_step is incremented only after the gradient updates have been applied.

  3. The statement apply_updates = state_ops.assign_add(global_step, 1, name=name). The assign_add op adds 1 to global_step in place, and this op is what apply_gradients returns, so running the training op also runs the increment (see the sketch after this list).
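
A quick way to confirm this behaviour is to minimize a toy loss and read the counter back after each update. A minimal TensorFlow 1.x sketch (not the project's code):

import tensorflow as tf

# Toy problem: minimize x^2; the names here are illustrative only.
x = tf.Variable(3.0)
loss = tf.square(x)

global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.GradientDescentOptimizer(0.1)
# minimize() = compute_gradients() + apply_gradients(); passing global_step makes
# apply_gradients attach the assign_add op discussed above to the returned train op.
train_step = optimizer.minimize(loss, global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        sess.run(train_step)
        print(sess.run(global_step))   # prints 1, 2, 3: one increment per gradient update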

Conclusion

Using global_step, i.e. the number of gradient updates, to decide when training stops is equivalent to using the number of iterations as the stopping criterion: in each iteration one batch is forward-propagated, gradients are computed, and the parameters receive one gradient update.
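
For example, with a hypothetical training set of 320,000 examples and a batch size of 32, one epoch takes 320,000 / 32 = 10,000 gradient updates, so n_updates = 10,000,000 would correspond to roughly 1,000 epochs; the two stopping criteria differ only in the unit being counted.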
