TensorFlow入门(七、检查点)

保存检查点

在实际的模型训练中,TensorFlow难免会出现中断的情况,使得到的中间参数丢失,因此需要在模型训练过程中及时将模型保存下来。并将这种在训练中保存模型的操作,称为保存检查点

通过设置saver的另一个参数max_to_keep,指定生成检查点文件的个数,代码示例如下:

saver = tf.train.Saver(max_to_keep = 1)

在保存模型时可以传入迭代次数,如:

saver.save(sess,saverdir + "linearmodel.cpkt",global_step = epoch)

载入时同样也要指定迭代次数,如:

saver.restore(sess2,saverdir + "linearmodel.cpkt-" + str(final_epoch))

快速获取检查点文件

快速获取检查点文件有两种方法:

①使用get_checkpoint_state函数,传入检查点文件路径作为参数,从而找到检查点文件。该函数返回的是checkpoint 文件CheckpointState proto类型的内容,其有model_checkpoint_path和all_model_checkpoint_paths两个属性。其中model_checkpoint_path保存了最新的检查点文件的文件名,all_model_checkpoint_paths则是未被删除的所有保存下来的检查点文件的文件名。

final_epoch = 18
ckpt = tf.train.get_checkpoint_state("log/")
with tf.Session() as sess2:
    sess2.run(init)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess2,ckpt.model_checkpoint_path)
    print("x=0.5,z=",sess2.run(z,feed_dict = {X:0.5}))

②使用latest_checkpoint()函数查找最新保存的检查点文件,该方法是速度最快的。

ckpt = tf.train.latest_checkpoint("log/")
with tf.Session() as sess2:
    sess2.run(init)
    if ckpt != None:
        saver.restore(sess2,ckpt)
    print("x=0.5,z=",sess2.run(z,feed_dict = {X:0.5}))

按照训练时间保存检查点

使用MonitoredTrainingSession()函数,该函数可以直接实现保存及载入检查点模型的文件,并且可以通过设置save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

#使用MonitoredTrainingSession()之前,必须定义global_step变量
global_step = tf.train.get_or_create_global_step()
checkpoint_step = tf.assign_add(global_step,1)

#定义检查点保存路径
saverdir = "log/checkpoints"

#启动session
with tf.train.MonitoredTrainingSession(checkpoint_dir=saverdir,save_checkpoint_secs=1) as sess:
    print("global_step=",sess.run([global_step]))
    #使用死循环,session不结束时就不结束
    while not sess.should_stop():
        i = sess.run(checkpoint_step)
        print("i=",i)

你可能感兴趣的:(TensorFlow入门,tensorflow,人工智能,python)