TensorFlow入门12 -- Checkpoints,保存和恢复Estimator创建的模型

参考:《深度学习图像识别技术--基于TensorFlow Object Detection API 和 OpenVINO》

模型训练好了后,下一步就是保存(Save)和恢复(restore)模型,TensorFlow提供两种模型格式(Model Format)

1,Checkpoints, 该格式依赖于创建模型的代码.

2,SavedModel, 该格式不依赖于创建模型的代码.

本文主要讨论检查点(Checkpoint).

如《从数据的角度理解TensorFlow鸢尾花分类程序6》一文所述,在创建tf.estimator.DNNClassifier对象时,其构造函数__init__有一个参数:

model_dir:保存模型参数的路径。(Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model.)

a.当没有指定的时候,Estimator 会将检查点文件写入由 Python 的 tempfile.mkdtemp函数选择的临时目录中。用语句 print(tempfile.gettempdir())可以查出本机的临时目录


tempfile.gettempdir

b.当指定了目录的时候,例如:model_dir = 'models/iris',Estimator 会将检查点文件写入~/models/iris


有了保存检查点文件路径后,tf.estimator.DNNClassifier对象会在运行train方法的时候,写入检查点文件,如下图所示:


train方法负责写入检查点文件

那train方法以什么频率写入检查点文件呢?

默认情况下,Estimator 按照以下时间安排将检查点保存到 model_dir 中:

a.每 10 分钟(600 秒)写入一个检查点。

b.在 train 方法开始(第一次迭代)和完成(最后一次迭代)时写入一个检查点。

c.只在目录中保留 5 个最近写入的检查点。

保存好检查点文件后,如何恢复模型呢?

Estimator 将一个检查点保存到 model_dir 中后,每次调用 Estimator 的 train、eval 或 predict 方法时,都会发生下列情况:

a) Estimator 通过运行 model_fn() 构建模型图。(要详细了解 model_fn(),请参阅创建自定义 Estimator。)

b) Estimator 根据最近写入的检查点中存储的数据来初始化新模型的权重。

换言之,如下图所示,一旦存在检查点,TensorFlow 就会在您每次调用 train()、evaluate() 或 predict() 时重建模型。


不当恢复

通过检查点恢复模型的状态这一操作仅在模型和检查点兼容时可行。例如,假设训练了一个 tf.estimator.DNNClassifier,它包含 2 个隐藏层且每层都有 10 个节点;在训练之后(TensorFlow已在 models/iris 中创建检查点),将每个隐藏层中的神经元数量从 10 更改为 3,然后重新训练模型,由于检查点中的状态与 修改后tf.estimator.DNNClassifier 中描述的模型不兼容,因此重新训练失败并出现以下错误,如下图所示:


不当恢复

解决不当恢复

1,当模型参数一直在变化的时候,最简单的方式是,不要指定model_dir,这样TensorFlow不会启动Checkpoint模型恢复,方便你随时修改模型。

2,启动Checkpoint的情况下,用Git为每个 model-dir 所需的代码保存一个副本,即为每个模型版本创建一个单独的 git 分支。这种区分将有助于保证检查点的可恢复性。

总结:检查点提供了一种简单的自动机制来保存和恢复由 Estimator 创建的模型。

你可能感兴趣的:(TensorFlow入门12 -- Checkpoints,保存和恢复Estimator创建的模型)