TensorFlow中Optimizer.minimize()与Optimizer.compute_gradients()和Optimizer.apply_gradients()的用法

1、Optimizer.minimize()

实际上根据官方文档的说明,minimize()就是compute_gradients()和apply_gradients()这两个方法的简单组合,minimize()的源码如下:

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None):
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)

主要的参数说明:

  •       loss:  `Tensor` ,需要优化的损失; 
  •       var_list: 需要更新的变量(tf.Varialble)组成的列表或者元组,默认值为`GraphKeys.TRAINABLE_VARIABLES`,即tf.trainable_variables()

注意:

1、Optimizer.minimize(loss, var_list)中,计算loss所涉及的变量(假设为var(loss))包含在var_list中,也就是var_list中含有多余的变量,并不 影响程序的运行,而且优化过程中不改变var_list里多出变量的值;

2、若var_list中的变量个数少于var(loss),则优化过程中只会更新var_list中的那些变量的值,var(loss)里多出的变量值 并不会改变,相当于固定了网络的某一部分的参数值。

2、compute_gradients()和apply_gradients()

compute_gradients()的源码如下:

compute_gradients(self, loss, var_list=None,
                  gate_gradients=GATE_OP,
                  aggregation_method=None,
                  colocate_gradients_with_ops=False,
                  grad_loss=None):

里面参数的定义与minimizer()函数里面的一致,var_list的默认值也一样。需要特殊说明的是,如果var_list里所包含的变量多于var(loss),则程序会报错。其返回值是(gradient, variable)对所组成的列表,返回的数据格式也都是“tf.Tensor”。我们可以通过变量名称的管理来过滤出里面的部分变量,以及对应的梯度。

apply_gradients()的源码如下:

apply_gradients(self, grads_and_vars, global_step=None, name=None)

grads_and_vars的格式就是compute_gradients()所返回的(gradient, variable)对,当然数据类型也是“tf.Tensor”,作用是,更新grads_and_vars中variable的梯度,不在里面的变量的梯度不变。

你可能感兴趣的:(TensorFlow)