【深度学习笔记(十六)】之tensorflow2中模型的保存与恢复

本文章由公号【开发小鸽】发布!欢迎关注!!!


老规矩–妹妹镇楼:

一. tensorflow2中模型的保存

(一)概述

       Tensorflow2中提供了 tf.train,Checkpoint这个变量保存与恢复类,可以使用save()和restore()方法将 Tensorflow中所有包含CheckPointable State的对象进行保存和恢复。

       如,优化器tf.keras.optimizer,变量tf.Variable,层tf.keras.Layer和模型tf.keras.Model实例都可以被保存。

(二)Checkpoint对象的保存

       要保存CheckPoint对象,首先需要声明一个Checkpoint,tf.train.Checkpoint()接受的初始化参数是一系列的键值对,键名任意取,值为需要保存的对象。这里的键名非常重要,在恢复模型时,需要使用同样的键名。

checkpoint = tf.train.Checkpoint(model=model)

       在模型训练完成后,对模型进行保存,使用save函数保存,参数为保存的目录名。

checkpoint.save(path)

如:

checkpoint.save(./save/model.ckpt’)

       save目录下就会生成三个文件:

checkpoint
model.ckpt-1.index
model.ckpt-1.data-00000-of-00001

       save()方法能够运行多次,每次运行后都能得到一个.index文件和一个.data文件,文件名的序号依次增加。


二. Checkpoint对象的恢复

(一)恢复模型

       当我们需要为某个模型加载之前已经训练好的参数时,需要再次实例化一个checkpoint,键名需要和保存的键名一致,再调用restore()方法恢复模型参数,restore()方法的参数为之前保存的文件的目录+编号。

       如”./save/model.ckpt-1”,跟之前保存的目录名相比添加了编号1,即第一个保存的文件。

如:

#待恢复参数的同一模型
model_to_be_restored = MyModel()
#新建checkpoint
checkpoint = tf.train.Checkpoint(model=model_to_be_restored)
checkpoint.restore(path_index)

(二)恢复最近的一个Checkpoint

       使用tf.train.latest_checkpoint(path)返回目录下最近一次checkpoint的文件名。

如:

checkpoint.restore(tf.train.latest_checkpoint(./save”))

三.保存与恢复变量代码模板

       先实例化checkpoint,继而训练,保存模型,再在其他模型中恢复保存的参数。

train.py

model = MyModel()
checkpoint = tf.train.CheckPoint((myModel=model)
#训练模型
checkpoint.save(./save/model.ckpt”)

test.py

model = MyModel()
checkpoint = tf.train.CheckPoint(myModel=model)
checkpoint.restore(tf.train.latest_checkpoint(./save”))
#模型使用

四. 管理checkpoint

       当checkpoint保存的数量过多时,我们可能只想保存最后几个checkpoint,或者希望使用其他的编号方式,这些都要通过checkpointManager来实现。

checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory=./save”, checkpoint_name=’model.ckpt’, max_to_keep=k)

       directory :文件保存的路径
       checkpoint_name :文件名前缀(不提供则默认为
ckpt )
       max_to_keep :保留的Checkpoint数目。

       然后使用manager,save()保存即可,如需要修改编号的形式,则在保存时添加checkpoint_number参数,如:将编号设置为100

manager.save(checkpoint_number=100)

你可能感兴趣的:(#,深度学习,python,tensorflow,深度学习)