Tensorflow中使用tf.train.Saver()和saver.restore()进行参数的保存和重现

1.Saver背景介绍

1.1我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。


1.2.Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。


1.3只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。
(第3点的内容我还没有理解上,以后写好了再更新哈)

 

2.代码

2.1 tf.train.Saver()如何写

2.1.1写的位置

(1)在with tf.Session() as sess:这一句的前面写就行,不用再Session()会话中,例如下面:

checkpoint_steps = 1
check_dir='D:/train/'
saver = tf.train.Saver()


log_dir = "D:/Temp/logs4/"

with tf.Session() as sess:
    
    train_writer = tf.summary.FileWriter(log_dir + "train/", sess.graph)  # 记录默认图
    test_writer = tf.summary.FileWriter(log_dir + "test/")

(2)在Session()会话的内部结尾中再加这么一句,指明地址

saver.save(sess,check_dir + 'model.ckpt') 

这两句就够了。

2.2 saver.restore()

2.2.1在哪里写

在训练结束全部结束后,再执行以下代码即可

saver = tf.train.Saver()
with tf.Session() as sess:
    model_file=tf.train.latest_checkpoint('D:/train/')
    saver.restore(sess,model_file)
    print("W4",sess.run(W4))

要注意到:print中的“W4”,这个是在定义网络结构中已经定义好的,这就需要在执行的时候,看清楚你当时定义的是什么,就在这里填。我的执行结果如下:

Tensorflow中使用tf.train.Saver()和saver.restore()进行参数的保存和重现_第1张图片

你可能感兴趣的:(tensorflow)