Tensorflow:如何加载神经网络部分参数

需求如下先用Resnet50在ImageNet上预训练,最后一层输出为类别数量,设为1000。然后将保存下来的参数迁移到PascalVoc上训练。

问题:由于PascalVoc只有20类,所以Resnet50最后一层输出要改为20。此时直接用tf.train.Saver()的restore,因为预训练的参数最后一层resnet50/fc长度为1000,而新模型最后一层resnet50/fc长度为20,会不匹配造成加载失败。

解决方法:利用tf.contrib.framework.get_variables_to_restore()函数,代码如下

variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['resnet50/fc'])
saver = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, param_path)

exclude=['resnet50/fc']表示加载预训练参数中除了resnet50/fc这一层之外的其他所有参数。

param_path是你预训练参数保存地址。

注:如果不止一个层参数需要丢弃,exclue=['a', 'b']即可。

调优训练(fine_tuning)时最好把前面曾trainable设为False,只训练最后一层。

你可能感兴趣的:(Tensorflow)