保存检查点
在实际的模型训练中,TensorFlow难免会出现中断的情况,使得到的中间参数丢失,因此需要在模型训练过程中及时将模型保存下来。并将这种在训练中保存模型的操作,称为保存检查点
通过设置saver的另一个参数max_to_keep,指定生成检查点文件的个数,代码示例如下:
saver = tf.train.Saver(max_to_keep = 1)
在保存模型时可以传入迭代次数,如:
saver.save(sess,saverdir + "linearmodel.cpkt",global_step = epoch)
载入时同样也要指定迭代次数,如:
saver.restore(sess2,saverdir + "linearmodel.cpkt-" + str(final_epoch))
快速获取检查点文件
快速获取检查点文件有两种方法:
①使用get_checkpoint_state函数,传入检查点文件路径作为参数,从而找到检查点文件。该函数返回的是checkpoint 文件CheckpointState proto类型的内容,其有model_checkpoint_path和all_model_checkpoint_paths两个属性。其中model_checkpoint_path保存了最新的检查点文件的文件名,all_model_checkpoint_paths则是未被删除的所有保存下来的检查点文件的文件名。
final_epoch = 18
ckpt = tf.train.get_checkpoint_state("log/")
with tf.Session() as sess2:
sess2.run(init)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess2,ckpt.model_checkpoint_path)
print("x=0.5,z=",sess2.run(z,feed_dict = {X:0.5}))
②使用latest_checkpoint()函数查找最新保存的检查点文件,该方法是速度最快的。
ckpt = tf.train.latest_checkpoint("log/")
with tf.Session() as sess2:
sess2.run(init)
if ckpt != None:
saver.restore(sess2,ckpt)
print("x=0.5,z=",sess2.run(z,feed_dict = {X:0.5}))
按照训练时间保存检查点
使用MonitoredTrainingSession()函数,该函数可以直接实现保存及载入检查点模型的文件,并且可以通过设置save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。
#使用MonitoredTrainingSession()之前,必须定义global_step变量
global_step = tf.train.get_or_create_global_step()
checkpoint_step = tf.assign_add(global_step,1)
#定义检查点保存路径
saverdir = "log/checkpoints"
#启动session
with tf.train.MonitoredTrainingSession(checkpoint_dir=saverdir,save_checkpoint_secs=1) as sess:
print("global_step=",sess.run([global_step]))
#使用死循环,session不结束时就不结束
while not sess.should_stop():
i = sess.run(checkpoint_step)
print("i=",i)