Tensorflow笔记——断点恢复

断点恢复

在做深度学习训练的时候,由于训练时间比较长,迭代次数比较多,经常会出现无法一次完成train的情况,那么这个时候我们需要用到tensorflow中的断点恢复。不多说直接上例子

Tensorflow笔记——断点恢复_第1张图片

#step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])通过文件名得到模型保存时迭代的轮数

#tf.train.get_checkpoint_state函数会通过checkpoint文件自动找到目录中最新模型的文件名

ckpt = tf.train.get_checkpoint_state(CKPT_PATH)
if ckpt and ckpt.model_checkpoint_path:
    #加载模型
    saver.restore(sess,ckpt.model_checkpoint_path)

 

存model的时候,当前step的值被赋予到global_step, 所以 在train的时候要把 global_step的值赋给step,这样才可以从断点处计算。

你可能感兴趣的:(Tensorflow笔记——断点恢复)