TensorFlow中多模型载入,以及多种加载参数的方式

最近做一个课题,要讲两个网络接在一起,自己摸索了好几天,查阅了h各种资料,最后勉强解决了问题.下面总结一下:

首先是保存参数的方法:
我最初都是用的最简单的方式保存图的参数,然后加载图的参数,见下面代码:

 saver = tf.train.Saver()
 saver.restore(sess, args.dir_models + 'model.ckpt')#加载模型
 saver.save(sess, (args.dir_models + 'model.ckpt'))#保存模型

相信大家这个都会,也是最开始接触的方式,但是遇到后面多模型导入,指定参数导入,就会不知所措。

下面就说明一下怎么多模型参数导入,不用声明两个sess,那样子我试过,被坑了。
看下面代码:(有注释)

 var = tf.global_variables()#取出全局中所有的参数
 var_flow_restore = [val for val in var if 'flownet' in val.name]#取出名字中有‘flownet’的参数
 saver = tf.train.Saver(var_flow_restore)#这句话就是关键了,可以网Saver中传参数
 saver.restore(sess, flow_model_dir + 'model.ckpt')#然后就往sess对应的图中导入了参数(var_flow_restore)

上面就导入了一个模型中自己想要导入的参数(这里我导入的是带有‘flownet’关键字的参数)

好了,在介绍一种导入参数的形式:(我感觉有缺陷)

exclude = ['discriminator','flownet']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
#这是slim 自带的方法,就是加载了所有变量,但是这些变量不包括带有['discriminator','flownet']关键字的参数,这里的['discriminator','flownet']只能位于参数名字的开头(这就是缺陷,要是我想去除的关键字不在参数名字开头呢)

saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, old_model_dir_model + 'snap-0')

最后介绍一种导入参数方法:

vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='导入参数的最开头的名字,也就是命名空间')
# print(vars_list)
 assign_ops = []
 for var in vars_list:
     # print('var_value:',var.value)
     vname = var.name
     from_name = vname
     # print('from_name:',from_name)
     var_value = tf.contrib.framework.load_variable(args.checkpoint_dir, from_name)
     # print('var_value:',var_value)
     assign_ops.append(tf.assign(var, var_value))

sess.run(assign_ops)

我研究了以上几种参数导入方法,最后我觉得第一种是最好的,下面是我用第一种方法导入另外一个模型代码:

variables_to_restore = [val for val in tf.global_variables() if 'inpaint_net' in val.name and 
                         'Adam' not in val.name and 'Adam_1' not in val.name]
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, old_model_dir_model + 'snap-0')

可以思考一下。

最后保存模型,也希望值保存自己想要保存的参数
说白了,就是网tf.train.Saver()中传入参数列表(list)
见代码

saver = tf.train.Saver(自己想要保存的参数list)
saver.save(sess,(model_dir_model+ 'model.ckpt'))

好了。总结完毕!

你可能感兴趣的:(python,深度学习)