tensorflow模型保存与载入

用Tensorflow做一些事情,最近遇到的问题是训练到的模型怎么用?

网上搜了一下,最常用的就是使用tf.train.Saver()定义saver变量,然后分别使用saver.save()和saver.restore()载入。

保存代码:

saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"results/model") 

保存后,在results目录会生成四个文件:

  • checkpoint:该文件记录了模型文件的路径信息
  • model.data-00000-of-0001:网络权重信息
  • model.index:二进制文件,保存了模型中的变量参数信息
  • model.meta:保存了模型的计算图结构信息

tensorflow版本是1.14.0,若tf2.0,因为不再用计算图了,应该会有所不同,以后用到再说。

当然,还可以将中间的训练数据保存下来,而不仅保存最终状态:

saver.save(agent.sess, 'result\\model',global_step=i,write_meta_graph=True)

加上global_step后在迭代过程中保存模型,会在模型文件后加上“-i”,生成一系列的文件。载入的时候可以载入其中任意一个进行查看了。其实这个是我需要的,不只要最终,过程的也要看一下。

载入代码:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('result\\model-141.meta')
    #saver.restore(sess, tf.train.latest_checkpoint('result\\'))
    saver.restore(sess, 'result\\model-141' )
    w1=sess.run('w:0')
    print(w1)

载入141次迭代的模型,并输出权重w1


能存能重载了就没再管,现在2020年了,查看结果的时候发现,模型只保存了5个训练步的,并没有保存所有。查看了tf.train.Saver()和saver.save的参数,才算明白,再来更新一波。

Saver类的构造函数定义如下:

def __init__(self,
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=saver_pb2.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None):

变量定义:
var_list: 指定要保存的变量的序列或字典,默认为None,即保存所有变量
reshape: 可选布尔型参数,True,表示允许变量以不同的形状保存,False,表示保持的变量只能有同样一种形状和数据类型,默认为False;
max_to_keep: 定义最多保存多少个最近模型文件,默认是5个,这正是我保存的模型只有五个的原因;
keep_checkpoint_every_n_hours: 定义多少个小时保存模型一次,默认10000个小时;
name: 可选参数,添加到操作名称前的前缀,默认None;
restore_sequentially:定义在设备上是否按照顺序恢复变量,顺序恢复可以降低内参使用,默认False;
saver_def:可选参数,用在需要重建Saver对象场合,默认None;
allow_empty:是否允许保存一个没有任何变量的空图,默认False;

saver.save函数定义:

def save(self,
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix="meta",
    write_meta_graph=True,
    write_state=True,
    strip_default_attrs=False):

常用参数:
sess: 当前的会话环境;
save_path: 模型保存路径;
global_step: 训练轮次,前文已说明;
latest_filename: checkpoint文本文件的名称,默认为‘checkpoint’
meta_graph_suffix: 保存的网络图结构文件的后缀,默认为mata;
write_meta_graph: 定义是否保存网络结构,默认是True保存,由于网络结构在训练过程中是不会变的,所以保存过一次后可以设置 write_meta_graph为False,不用每次都保存图结构;

你可能感兴趣的:(Tensorflow学习)