tensorflow 恢复部分参数、加载指定参数

多分类采用与训练模型输出不匹配,我们需要加载部分预训练模型的参数。

我们先看一下如何保存和读入预训练模型。

#一般实验情况下保存的时候,都是用的saver类来保存,如下
saver = tf.train.Saver()
saver.save(sess,"model.ckpt")

#加载时的代码
saver.restore(sess,"model.ckpt")

#前面的描述相当于是保存了所有的参数,然后加载所有的参数。
#但是目前的情况有所变化了,不能加载所有的参数,最后一层的参数不一样了,需要随机初始化。
#首先对每一层添加name scope,如下:

with name_scope('conv1'):
        xxx
with name_scope('conv2'):
        xxx
with name_scope('fc1'):
        xxx
with name_scope('output'):
        xxx
#然后根据变量的名字,选择加载哪些变量,

#得到该网络中,所有可以加载的参数
variables = tf.contrib.framework.get_variables_to_restore()
#删除output层中的参数
variables_to_resotre = [v for v in varialbes if v.name.split('/')[0]!='output']
#构建这部分参数的
saversaver = tf.train.Saver(variables_to_restore)
saver.restore(sess,'model.ckpt')

#在tensorflow中,有多种方式可以得到变量的信息:
tf.contrib.framework.get_variables_to_restore()
tf.all_variables()tf.trainable_varialbes()

 

 

 

 

多分类采用与训练模型输出不匹配解决方法:

利用tf.contrib.framework.get_variables_to_restore()函数,代码如下

variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['resnet50/fc'])
saver = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, param_path)

 

exclude=['resnet50/fc']表示加载预训练参数中除了resnet50/fc这一层之外的其他所有参数。

include=["inceptionv3"]表示只加载inceptionv3这一层的所有参数。

param_path是你预训练参数保存地址。

注:如果不止一个层参数需要丢弃,exclue=['a', 'b']即可。调优训练(fine_tuning)时最好把前面曾trainable设为False,只训练最后一层。
 

你可能感兴趣的:(tensorflow,Python学习)