tensorflow 冻结梯度

一、背景

        有时候在训练模型的时候,会有这样的需求:某个loss,只想影响一部分网络参数的更新,而另外一部分网络参数不想受这个loss的影响,特别是像多目标的多塔结构的模型。

二、实现

a = weight1 + weight2
a_stopped = tf.stop_gradient(a)
y3 = a_stopped + weight3

gradients1 = tf.gradients(y3, [weight1, weight2, weight3], grad_ys=[tf.convert_to_tensor([1., 2.])])
gradients2 = tf.gradients(y3, [weight3], grad_ys=[tf.convert_to_tensor([1., 2.])])
print(gradients1)  # [None, None, < tf.Tensor 'gradients_1/grad_ys_0:0' shape = (2,) dtype = float32 >]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    '''
    下面代码会报错
    因为weight1、weight2 的梯度被停止,程序试图去求一个None的梯度,所以报错
    注释掉gradients1
    求 gradients2 就又正确了
    '''
    # print(sess.run(gradients1))
    print(sess.run(gradients2))

对于y3来说,就相当于把a_stopped从变量(和weight1和weight2有关)变成一个常量,所以对他求导不需要在执行链式求导法则,对常量求偏导就是等于0

你可能感兴趣的:(tensorflow,TensorFlow,梯度冻结)