TensorFlow笔记(10) CheckPoint

TensorFlow笔记(10) CheckPoint

  • 1. checkpoint
  • 2. 保存检查点
  • 3. 读取检查点


1. checkpoint

随着数据的复杂性和网络深度的加深,训练的强度就会加大
万一电脑训练太久炸裂,或者突然怎么了断电了和不小心关掉了
那岂不是得重新训练

不存在的,可以用保存检查点checkpoint,把当前所有可训练变量值存储到检查点文件
需要时再重新拿来加载就行了,这样就可以实现断点训练


2. 保存检查点

为了得到可以用来后续恢复模型以进一步训练或评估的检查点文件(checkpoint file)
首先需要生成saver

saver = tf.train.Saver()

在训练循环中,将定期调用saver.save()方法
向训练文件夹中写入包含了当前所有可训练变量值得检查点文件

saver.save(sess, FLAGS.train_dir, global_step=step)

举个简单的例子

import tensorflow as tf

# 定义变量
a = tf.Variable(1, name="a")
b = tf.Variable(2, name="b")

# 定义保存模型
saver = tf.train.Saver()
save_dir = "../save_path/test_model/"

# 定义模型序号
step = 0

with tf.Session() as sess:
    # 变量初始化
    sess.run(tf.initialize_all_variables())
    print("v1 =", a.eval())
    print("v2 =", b.eval())
    save_path = saver.save(sess, save_dir + "model", global_step=step)
    print("Model saved in file: ", save_path)

# v1 = 1
# v2 = 2
# Model saved in file:  ../save_path/test_model/model-0

这样就生成了checkpoint file
TensorFlow笔记(10) CheckPoint_第1张图片


3. 读取检查点

使用saver.restore()方法,重载模型的参数,继续运行

saver.restore(sess, FLAGS.train_dir)

举个简单的例子,读取刚刚存储的checkpoint file

import tensorflow as tf

# 定义变量
a = tf.Variable(0, name="a")
b = tf.Variable(0, name="b")

# 定义保存模型
saver = tf.train.Saver()
save_dir = "../save_path/test_model/"

# 定义模型序号
step = 0

with tf.Session() as sess:
    # 恢复保存模型
    # 如果有检查点文件, 读取最新的检查点文件,恢复各种变量值
    ckpt_dir = tf.train.latest_checkpoint(save_dir)
    if ckpt_dir != None:
        saver.restore(sess, ckpt_dir)
    else:
        # 变量初始化
        sess.run(tf.initialize_all_variables())
    # 或者直接读取
    # saver.restore(sess, save_dir + "model-{}".format(step))
    print("v1 =", a.eval())
    print("v2 =", b.eval())

# v1 = 1
# v2 = 2

可以看到,读取的不是现在初始化0,而是原先模型的数值


[1] python的代码地址:
https://github.com/JoveH-H/TensorFlow/blob/master/py/7.1.CheckPoint_svae.py
https://github.com/JoveH-H/TensorFlow/blob/master/py/7.2.CheckPoint_restore.py
[2] jupyter notebook的代码地址:
https://github.com/JoveH-H/TensorFlow/blob/master/ipynb/7.1.CheckPoint_svae.ipynb
https://github.com/JoveH-H/TensorFlow/blob/master/ipynb/7.2.CheckPoint_restore.ipynb


相关推荐:

TensorFlow笔记(9) ResNet
TensorFlow笔记(8) LeNet-5卷积神经网络
TensorFlow笔记(7) 多神经元分类
TensorFlow笔记(6) 单神经元分类
TensorFlow笔记(5) 多元线性回归


谢谢!

你可能感兴趣的:(TensorFlow笔记)