What output_gradients does in tf.GradientTape()

tf.keras.optimizers.Optimizer has a minimize() method:

minimize(
    loss, var_list, grad_loss=None, name=None, tape=None
)

The parameter grad_loss is documented as "(Optional). A Tensor holding the gradient computed for loss", which is puzzling. It is in fact the same thing as the output_gradients parameter of tf.GradientTape()'s gradient() method:

gradient(
    target,
    sources,
    output_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)
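
With a tape supplied, minimize() with grad_loss=g boils down to computing gradients with output_gradients=g and then applying them. Below is a minimal sketch of that equivalence (all values illustrative; note that on TF 2.11+ the grad_loss-accepting minimize() belongs to the legacy optimizers, hence tf.keras.optimizers.legacy.SGD):

import tensorflow as tf

opt = tf.keras.optimizers.legacy.SGD(1.0)
w1 = tf.Variable([1.0, 2.0])
w2 = tf.Variable([1.0, 2.0])
g = tf.constant([2.0, 5.0])

with tf.GradientTape(persistent=True) as tape:
    loss1 = 3.0 * w1  # d(loss1)/d(w1) = [3, 3]
    loss2 = 3.0 * w2  # identical graph for the manual path

# Path 1: minimize() with grad_loss.
opt.minimize(loss1, [w1], grad_loss=g, tape=tape)

# Path 2: gradient() with output_gradients, then apply_gradients().
grads = tape.gradient(loss2, [w2], output_gradients=g)
opt.apply_gradients(zip(grads, [w2]))

tf.print(w1, w2)  # both [-5 -13]: [1, 2] - 1.0 * [3, 3] * [2, 5]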

output_gradients is documented as "a list of gradients, one for each element of target. Defaults to None", which is just as puzzling. What it actually does is attach a multiplier to the gradient, as described in tensorflow/python/eager/imperative_grad.py:

output_gradients: if not None, a list of gradient provided for each Target,
    or None if we are to use the target's computed downstream gradient.
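
In other words, output_gradients supplies the vector u of a vector-Jacobian product: gradient() then returns u^T * J instead of using the default all-ones upstream gradient. A quick illustration (values arbitrary):

import tensorflow as tf

x = tf.Variable([1.0, 2.0])
with tf.GradientTape(persistent=True) as tape:
    y = x * x                                        # dy_i/dx_i = 2 * x_i
tf.print(tape.gradient(y, x))                        # implicit u = [1, 1]: [2 4]
u = tf.constant([3.0, 0.5])
tf.print(tape.gradient(y, x, output_gradients=u))    # u * dy/dx: [6 2]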

So output_gradients is there to carry the upstream gradient. Ordinarily you never need to touch it; one suitable use case is manual chain-rule differentiation. Concretely:

import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(1.0)  # on TF >= 2.11 use tf.keras.optimizers.legacy.SGD, whose minimize() still accepts grad_loss
w = [tf.Variable([0.0, 0.1]), tf.Variable([0.0, 0.2])]
x = 2.0
with tf.GradientTape(persistent=True) as tape:
    y = x * (w[0] + w[1])
tf.print(y)  # [0 0.6]
grads = tape.gradient(y, w)
tf.print(grads)  # [[2 2], [2 2]]
# grad_loss must mirror the structure of the loss: y is a single tensor,
# so a single tensor of upstream gradients is passed.
optimizer.minimize(y, w, grad_loss=tf.constant([1.0, 2.0]), tape=tape)
tf.print(w)  # [[-2 -3.9], [-2 -3.8]]
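
Checking the arithmetic: each variable receives the gradient dy/dw_i * grad_loss = 2 * [1, 2] = [2, 4], so with a learning rate of 1.0, w[0] moves from [0, 0.1] to [-2, -3.9] and w[1] from [0, 0.2] to [-2, -3.8].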

w = tf.Variable([1.0, 0.0])
x = 2.0
with tf.GradientTape() as tape:
    y = x * w
# dy/dw = [2, 2]; scaled elementwise by grad_loss = [2, 5] it becomes [4, 10]
optimizer.minimize(y, w, grad_loss=tf.constant([2.0, 5.0]), tape=tape)
tf.print(w)  # [1, 0] - [4, 10] = [-3 -10]

w = [tf.Variable([0.0, 0.1]), tf.Variable([0.0, 0.2])]
x = 2.0
with tf.GradientTape(persistent=True) as tape:
    y = tf.reduce_mean(x * (w[0] + w[1]))
    y2 = 7.0 * y
grads = tape.gradient(y2, w)
tf.print(grads)  # [[7 7], [7 7]]
grads = tape.gradient(y, w, output_gradients=tape.gradient(y2, y))  # manual chain rule: dy2/dy * dy/dw
tf.print(grads)  # [[7 7], [7 7]]
# output_gradients = tf.constant([-1.0])
output_gradients = -1.0
grads = tape.gradient(y2, w, output_gradients=output_gradients)  # manual intervention: flip the sign
tf.print(grads)  # [[-7 -7], [-7 -7]]
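
A more realistic home for this manual chaining is carrying gradients across two separate tapes, e.g. when a forward pass is recorded in stages. A minimal sketch (the two-stage split below is illustrative, not part of the original example):

import tensorflow as tf

w1 = tf.Variable(3.0)
w2 = tf.Variable(5.0)
x = tf.constant(2.0)

# Stage 1 and stage 2 are recorded on separate tapes.
with tf.GradientTape() as tape1:
    h = w1 * x                  # h = 6
with tf.GradientTape() as tape2:
    tape2.watch(h)              # h is a plain tensor, so watch it explicitly
    loss = w2 * h               # loss = 30

# Backprop stage 2 first, then feed dloss/dh into stage 1 via output_gradients.
dloss_dh = tape2.gradient(loss, h)                            # w2 = 5
dloss_dw1 = tape1.gradient(h, w1, output_gradients=dloss_dh)
tf.print(dloss_dw1)             # x * w2 = 10, matching d(w2*w1*x)/dw1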

In the code above, manual intervention can flip the direction of the gradient, which enables a slick trick: in a GAN, compute the discriminator loss just once, and when optimizing the generator attach a -1.0 multiplier to that same loss. That said, the trick is not recommended: well-designed GANs generally use different losses for the discriminator and the generator, and at the coding level such a design, clever as it is, hurts readability and comprehension.
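
For completeness, here is a hedged sketch of what that trick could look like; the toy generator/discriminator and the least-squares-style loss below are illustrative stand-ins, not a recommended GAN recipe:

import tensorflow as tf

# Toy stand-ins for a generator and a discriminator (illustrative only).
gen = tf.keras.Sequential([tf.keras.layers.Dense(4)])
disc = tf.keras.Sequential([tf.keras.layers.Dense(1)])
g_opt = tf.keras.optimizers.SGD(0.1)
d_opt = tf.keras.optimizers.SGD(0.1)

z = tf.random.normal([8, 4])     # latent noise
real = tf.random.normal([8, 4])  # stand-in "real" batch

with tf.GradientTape(persistent=True) as tape:
    fake = gen(z)
    # One discriminator loss, computed once (least-squares style).
    d_loss = tf.reduce_mean(tf.square(disc(real) - 1.0) + tf.square(disc(fake)))

# The discriminator descends on d_loss as usual.
d_grads = tape.gradient(d_loss, disc.trainable_variables)
d_opt.apply_gradients(zip(d_grads, disc.trainable_variables))

# The generator reuses the same d_loss but flips its sign via
# output_gradients, i.e. it ascends on the discriminator's loss.
g_grads = tape.gradient(d_loss, gen.trainable_variables,
                        output_gradients=tf.constant(-1.0))
g_opt.apply_gradients(zip(g_grads, gen.trainable_variables))
del tape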
