模型保存:Checkpoints&SavedModel

本文介绍了 Estimators 模型的保存和恢复。(官方文档连接:https://www.tensorflow.org/guide/checkpoint)
TensorFlow提供了两种模型格式:
  checkpoints:这种格式依赖于创建模型的代码。
  SavedModel:这种格式与创建模型的代码无关。

1、Checkpoints
  • checkpoints是什么?
       - 在tensorflow中checkpoints文件是一个二进制文件,用于存储所有的weights,biases,gradients和其他variables的值。.meta文件则用于存储 graph中所有的variables, operations, collections等。简言之一个存储参数,一个存储图。
      -“checkpoint”文件仅用于告知某些TF函数,这是最新的检查点文件。
      - .ckpt-meta 包含元图,即计算图的结构,没有变量的值(基本上你可以在tensorboard / graph中看到)。
      - .ckpt-data包含所有变量的值,没有结构。要在python中恢复模型,您通常会使用元数据和数据文件(但也可以使用.pb文件): saver = tf.train.import_meta_graph(path_to_ckpt_meta) saver.restore(sess, path_to_ckpt_data)
      -.ckpt-index是内部需要的某种索引来正确映射前两个文件。它通常不是必需的,可以只用.ckpt-meta和恢复一个模型.ckpt-data.pb文件可以保存您的整个图表(元+数据),要在c ++中加载和使用(但不训练)图形,通常会使用它来创建[freeze_graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py),它会.pb从元数据和数据创建文件。要小心,(至少在以前的TF版本和某些人中)py提供的功能freeze_graph不能正常工作,所以你必须使用脚本版本。Tensorflow还提供了一种tf.train.Saver.to_proto()`方法。

  • 保存经过部分训练的模型
    Estimator自动将如下内容写入磁盘
      - checkpoints: 训练期间所创建的模型版本
      - event files: 包含有TensorBoard用于创建可视化图标的全部信息
    如果要指定模型的顶级存储目录,可以使用Estimator构造函数的可选参数model_dir,设置代码如下所示:

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir="./models_dir")

当调用Estimator的train方法时,Estimator会将checkpoint和其他文件保存到model_dir目录中,保存之后,这个目录中的文件如下所示:

checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta

这个目录存储的是Estimator在第一步训练开始和第200不训练结束时创建的checkpoints

  • Checkpoint频率
    默认情况下,Estimator按照如下时间将checkpoint保存到model_dir
      - 每600秒保存一次
      - 在train方法开始以及完成时都要保存checkpoint
      - 在目录中最多保留5个最近的checkpoints
    可以通过如下步骤来更改默认设置:
  1. 创建RunConfig对象来自定义设置
  2. 在实例化Estimator时,将该```RunConfig对象传递个Estimatro的config``参数
my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,
    keep_checkpoint_max = 10,
)
  • 变量的保存与恢复tf.train.Checkpoint
    TensorFlow 提供了 tf.train.Checkpoint这一强大的变量保存与恢复类,可以使用其 save()restore()方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizer 、 tf.Variable 、 tf.keras.Layer或者 tf.keras.Model实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:
checkpoint = tf.train.Checkpoint(model=model)

 这里tf.train.Checkpoint()接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 tf.keras.Model的模型实例 model`` 和一个继承tf.train.Optimizer的优化器 optimizer`` ,我们可以这样写:

checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)

 这里myAwesomeModel 是我们为待保存的模型 model`` 所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。 接下来,当模型训练完成需要保存的时候,使用(save_path_with_prefix``` 是保存文件的目录 + 前缀。):

checkpoint.save(save_path_with_prefix)

 例如,在源代码目录建立一个名为save的文件夹并调用一次 checkpoint.save('./save/model.ckpt'),我们就可以在可以在 save 目录下发现名为 ``checkpoint model.ckpt-1.index model.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。checkpoint.save() 方法可以运行多次,每运行一次都会得到一个. index文件和. data ```文件,序号依次累加。
 当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:

model_to_be_restored = MyModel()                                        # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)   # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)

 当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path)这个辅助函数f。例如如果save 目录下有 model.ckpt-1.indexmodel.ckpt-10.index的 10 个保存文件, tf.train.latest_checkpoint('./save') 即返回 ./save/model.ckpt-10

总体而言,恢复与保存变量的典型代码框架如下:

# train.py 模型训练阶段

model = MyModel()
# 实例化Checkpoint,指定保存对象为model(如果需要保存Optimizer的参数也可加入)
checkpoint = tf.train.Checkpoint(myModel=model)
# ...(模型训练代码)
# 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次)
checkpoint.save('./save/model.ckpt')
# test.py 模型使用阶段

model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model)             # 实例化Checkpoint,指定恢复对象为model
checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数
# 模型使用代码

使用 TensorFlow 的 tf.train.CheckpointManager设置保存的数量:

checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)  
# 此处, directory 参数为文件保存的路径, checkpoint_name 为文件名前缀(不提供则默认为 ckpt ),
#  max_to_keep 为保留的 Checkpoint 数目。

实战参考连接:https://www.jianshu.com/p/5006be1c5f59

你可能感兴趣的:(模型保存:Checkpoints&SavedModel)