对指定的部分变量梯度更新

在训练模型,有时需要对某些变量停止梯度更新,比如蒸馏时teacher的weight保持不变,有一种简单的方式通过scope控制哪些变量是否进行梯度更新,tensorflow的tf.get_collection(key, scope=None)函数获取需要更新梯度的变量:

var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=‘teacher’)

如果指定key,就返回名称域(scope)中所有放入‘key’的变量的列表,如果不指定scope则返回所有变量。

然后通过optimizer来进行梯度更新:

train_op = optimizer.minimize(self.loss, var_list=update_ops)

这时只会更新带有teacher的scope变量的权重,而其他的变量则不会进行梯度更新

https://blog.csdn.net/qq_43088815/article/details/89926074

你可能感兴趣的:(对指定的部分变量梯度更新)