apply_gradients and compute_gradients are methods that every optimizer provides.
compute_gradients(
loss,
var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None
)
Computes the gradients of loss with respect to the trainable variables in var_list. This is the first step of minimize(); it returns a list of (gradient, variable) pairs.
loss:
A Tensor containing the value to minimize or a callable taking no arguments which returns the value to minimize. When eager execution is enabled it must be a callable.
var_list:
Optional list or tuple of tf.Variable to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.
gate_gradients:
How to gate the computation of gradients. Can be GATE_NONE, GATE_OP, or GATE_GRAPH.
aggregation_method:
Specifies the method used to combine gradient terms. Valid values are defined in the class AggregationMethod.
colocate_gradients_with_ops:
If True, try colocating gradients with the corresponding op.
grad_loss:
Optional. A Tensor holding the gradient computed for loss.
Returns:
A list of (gradient, variable) pairs. Variable is always present, but gradient can be None.
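To make the shape of that return value concrete, here is a minimal pure-Python sketch of the "compute" half. It is not TensorFlow code: `toy_compute_gradients`, `TrackingDict`, the dict of named variables, and the finite-difference estimate are all hypothetical stand-ins chosen for illustration. The sketch returns one (gradient, variable) pair per variable and reports None for a variable the loss never reads.

```python
class TrackingDict(dict):
    """Hypothetical helper: a dict that records which keys were read."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.read = set()

    def __getitem__(self, key):
        self.read.add(key)
        return super().__getitem__(key)


def toy_compute_gradients(loss_fn, variables, eps=1e-6):
    """Return a list of (gradient, variable_name) pairs.

    Assumption: gradients are estimated with central finite differences,
    which is only a stand-in for real automatic differentiation.
    """
    # Evaluate the loss once to find out which variables it depends on.
    probe = TrackingDict(variables)
    loss_fn(probe)

    pairs = []
    for name in variables:
        if name not in probe.read:
            # The variable is still listed, but it has no gradient.
            pairs.append((None, name))
            continue
        up = dict(variables)
        up[name] += eps
        down = dict(variables)
        down[name] -= eps
        grad = (loss_fn(up) - loss_fn(down)) / (2 * eps)
        pairs.append((grad, name))
    return pairs


def loss_fn(v):
    # Toy loss: depends on "w" and "b" only.
    return (v["w"] - 3.0) ** 2 + v["b"] ** 2


variables = {"w": 0.0, "b": 1.0, "unused": 5.0}
pairs = toy_compute_gradients(loss_fn, variables)
for grad, name in pairs:
    print(name, grad)
```

The pair for "unused" carries None, which is why any code that transforms the gradients before applying them should check for None first.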
apply_gradients(
grads_and_vars,
global_step=None,
name=None
)
The second step of minimize(); it returns an op that applies the gradient updates.
grads_and_vars:
List of (gradient, variable) pairs as returned by compute_gradients().
global_step:
Optional Variable to increment by one after the variables have been updated.
name:
Optional name for the returned operation. Defaults to the name passed to the Optimizer constructor.
Returns:
An Operation that applies the specified gradients. If global_step was not None, that operation also increments global_step.
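The "apply" half can be pictured the same way. The following is a rough pure-Python sketch, not the TensorFlow implementation: `toy_apply_gradients`, the dict of named variables, and the plain gradient-descent update rule are assumptions made for illustration. It consumes (gradient, variable) pairs, skips pairs whose gradient is None, and increments a step counter when one is supplied.

```python
def toy_apply_gradients(grads_and_vars, variables, learning_rate, global_step=None):
    """Update `variables` in place and return the new step count (or None).

    Assumption: the update rule is plain gradient descent,
    var <- var - learning_rate * grad.
    """
    for grad, name in grads_and_vars:
        if grad is None:
            continue  # nothing to apply for this variable
        variables[name] -= learning_rate * grad
    # Mirror the documented behaviour: only count a step if a counter was given.
    return None if global_step is None else global_step + 1


variables = {"w": 0.0, "b": 1.0}
grads_and_vars = [(-6.0, "w"), (None, "b")]

step = toy_apply_gradients(grads_and_vars, variables,
                           learning_rate=0.1, global_step=0)
print(variables, step)
```

Splitting the update into a "compute" call and an "apply" call is what leaves room to inspect or modify the gradients in between.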
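# Now we apply gradient clipping. For this, we need to get the gradients,
# use the `clip_by_value()` function to clip them, then apply them.
# Assumes TensorFlow 1.x graph mode, with `loss` (a scalar Tensor) and
# `learning_rate` already defined.
import tensorflow as tf

threshold = 1.0
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# A list of (gradient, variable) tuples
grads_and_vars = optimizer.compute_gradients(loss)
# Clip each gradient to [-threshold, threshold]; skip variables whose
# gradient is None, because clip_by_value() cannot take None
capped_gvs = [(tf.clip_by_value(grad, -threshold, threshold), var)
              for grad, var in grads_and_vars if grad is not None]
# Op that applies the clipped gradients to the corresponding variables
training_op = optimizer.apply_gradients(capped_gvs)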
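What the clipping step does to each gradient can be shown without TensorFlow. This is only an illustrative sketch: the local `clip_by_value` below is a hypothetical pure-Python stand-in that clamps every element of a list into [low, high], and the sample gradients are made up.

```python
def clip_by_value(values, low, high):
    """Clamp every element of `values` into the range [low, high]."""
    return [min(max(v, low), high) for v in values]


threshold = 1.0
# Made-up (gradient, variable) pairs; "b" has no gradient.
grads_and_vars = [([0.3, -2.5, 7.0], "w"), (None, "b")]

# Clip each gradient elementwise, dropping pairs whose gradient is None.
capped_gvs = [(clip_by_value(grad, -threshold, threshold), var)
              for grad, var in grads_and_vars if grad is not None]
print(capped_gvs)  # [([0.3, -1.0, 1.0], 'w')]
```

Elements already inside the range are left untouched; only the components that exceed the threshold are pulled back to it.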