深度学习学习笔记——tensorflow中的tf.Variable与函数trainable_variable及冻结网络

tensorflow中tf.constant()往往生成常数不可变,而tf.Constant()往往生成变量(如神经网络中待训练的权重与偏置参数),之所以可变是因为这个函数有一个参数,trainable默认为True(trainable=True)。
一、Model.trainable_variables
返回modle中所有的可训练参数,通常使用在梯度下降法时,如:

    with tf.GradientTape() as t:
        z = tf.random.normal(shape=(args.batch_size, 1, 1, args.z_dim))
        x_fake = G(z, training=True)
        x_fake_d_logit = D(x_fake, training=True)
        G_loss = g_loss_fn(x_fake_d_logit)

    G_grad = t.gradient(G_loss, G.trainable_variables)

二、model.trainable
“冻结”一个层指的是该层将不参加网络训练,即该层的权重永不会更新。在进行fine-tune时我们经常会需要这项操作。 在使用固定的embedding层处理文本输入时,也需要这个技术。

可以通过向层的构造函数传递trainable参数来指定一个层是不是可训练的,如:

frozen_layer = Dense(32,trainable=False)

此外,也可以通过将层对象的trainable属性设为True或False来为已经搭建好的模型设置要冻结的层。 在设置完后,需要运行compile来使设置生效,例如:

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)

frozen_model = Model(x, y)
# in the model below, the weights of `layer` will not be updated during training
frozen_model.compile(optimizer='rmsprop', loss='mse')

layer.trainable = True
trainable_model = Model(x, y)
# with this model the weights of the layer will be updated during training
# (which will also affect the above model since it uses the same layer instance)
trainable_model.compile(optimizer='rmsprop', loss='mse')

frozen_model.fit(data, labels)  # this does NOT update the weights of `layer`
trainable_model.fit(data, labels)  # this updates the weights of `layer`

转自

你可能感兴趣的:(深度学习学习笔记,深度学习)