tensorflow 恢复部分参数、加载指定参数

现实中碰到一个问题,训练好分类模型,比如训练保存了一个10分类的模型,但是实际用的时候呢,可能是做20分类,但是还想继续使用前面保存的模型。那么相当于是只加载前几层的参数,最后一层做一些修改。


一般实验情况下保存的时候,都是用的saver类来保存,如下

saver = tf.train.Saver()
saver.save(sess,"model.ckpt")
加载时的代码

saver.restore(sess,"model.ckpt")


前面的描述相当于是保存了所有的参数,然后加载所有的参数。但是目前的情况有所变化了,不能加载所有的参数,最后一层的参数不一样了,需要随机初始化。如何操作呢?

首先对每一层添加name scope,如下:

with name_scope('conv1'):
        xxx
with name_scope('conv2'):
        xxx
with name_scope('fc1'):
        xxx
with name_scope('output'):
        xxx

然后根据变量的名字,选择加载哪些变量,

#得到该网络中,所有可以加载的参数
variables = tf.contrib.framework.get_variables_to_restore()
#删除output层中的参数
variables_to_resotre = [v for v in varialbes if v.name.split('/')[0]!='output']
#构建这部分参数的saver
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess,'model.ckpt')

在tensorflow中,有多种方式可以得到变量的信息:

tf.contrib.framework.get_variables_to_restore()
tf.all_variables()
tf.trainable_varialbes()

等等,可以多看看API

你可能感兴趣的:(tensorflow)