编程速记(26): Tensorflow篇-模型的恢复与保存-tf.saver

一、需求

1.每次找到文件夹中保存的最后一次epoch的次数
2.如果1.存在,那么将模型恢复到该epoch的参数
3.每跑完一个epoch,保存模型参数

二、demo

准备工作

1.定义一个保存epoch的文件夹路径save_dir
2.定义保存模型的文件名pattern,比如:epoch-300.
例如:如果当前是第10个epoch,则参数保存为model_10.pth

tf.saver机制简述

基于saver.save()一般会生成四个文件。

  • 尾缀为data的文件,主要保存模型的参数数据,以字典的形式
  • 尾缀为meta的文件,主要保存模型的图和元数据,可以被 tf.train.import_meta_graph 加载到当前默认的图。
  • checkpoint文件,这是一个监督文件,可以帮助恢复到当前最后一个保存的模型的参数(tf.train.latest_checkpoint(save_model_dir))
  • 尾缀为index的文件

实现

代码片段:初始化

g_list = tf.global_variables()
saver = tf.train.Saver(var_list=g_list)
save_model_dir = os.path.join('./models', args.save_model)
if not os.path.exists(save_model_dir):
    os.makedirs(save_model_dir)

代码片段:模型保存

#save
save_path = saver.save(sess, os.path.join(save_model_dir, 'epoch-{}'.format(epoch)), global_step=global_cnt)

代码片段:模型恢复

#restore
if os.path.exists(os.path.join(save_model_dir,'checkpoint')):
    init_epoch = re.findall(".*epoch-(.*)-.*",tf.train.latest_checkpoint(save_model_dir))# tf.train.latest_checkpoint(path)返回的是path提供的路径中的最新的模型文件名称(无后缀)——基于checkpoint文件
    saver.restore(sess,tf.train.latest_checkpoint(save_model_dir))
    init_epoch = int (init_epoch[0])
    print("restored from epoch :{}".format(init_epoch))

三、参考

https://www.geek-share.com/detail/2752211124.html

你可能感兴趣的:(编程速记)