Tensorflow2.x 模型保存(持续更新)

至少从我个人使用的角度
Tensorflow2.x 的模型保存、checkpoint 的官方教程也太乱了。
简单的流程被说的云里雾里的
一般流程如下

定义模型
定义检查点
restore 如果没有已有的保存 则直接创建新的检查点
训练,更新
save

如果训练到一半断电 关闭 恢复时会可以从指定检查点,一般是最近检查点恢复。
下面是几点问题

1. Checkpoint

my_checkpoint = tf.train.Checkpoint(model=model)

构建自己的checkpoint实例。
支持
tf.Variables 如果是模型之外的变量,记得assign更新 而非+= := 等更新
tf.keras.optimizers.Optimizer
tf.data.Dataset
tf.keras.Layer
tf.keras.Model
适用于以上trackable objects

一般的Checkpoint的root参数默认为None 且不填
其余的kwargs 会变成my_checkpoint 的属性。

2. checkpoint_directory && checkpoint_prefix

checkpoint_directory 只存放若干检查点的目录, 目录结构一般如下

checkpoint (可以文本方式读取和修改)
???-n.data-xxxxxx-of-xxxxxx 若干 n 代表检查点的序号  xxxxxx指示数据数据存储有关的东西 不影响
???-n.index 若干 n 代表检查点的序号  xxxxxx指示数据数据存储有关的东西 不影响

checkpoint_directory+???一般就是prefix的形式

checkpoint_prefix = os.path.join(checkpoint_directory,"ckp11t")
checkpoint.save(file_prefix=checkpoint_prefix)

那么检查点就会保存在 checkpoint_directory的目录下,以ckp11t-n开头,这也就是prefix的意义

2. latest_checkpoint

tmp=tf.train.latest_checkpoint(checkpoint_directory)

latest_checkpoint(checkpoint_directory)的机制是,读取checkpoint 文件,
获取checkpoint_directory目录下的最新的检查点的地址(不包括后缀和点号)。
如第100个检查点保存后读取,则返回“…\training_checkpoints\ckp11t-100"
的地址。但是不止于此,同时会检查对应的文件是否存在,如果不存在,报错,返回空地址,不会抛出异常中断程序。

一般的
如果自行修改了checkpoint 文件中的检查点序号,就可以指定目标检查点,从而实现从已有的任意位置继续,覆盖后续的检查点。

不一般的
如果修改了checkpoint 文件中的检查点序号为未出现的检查点序号,比如100改成10000,但是才训练了101次,根本没有10000检查点,相应的,将第100个检查点文件名中的100改成10000,依旧可以正常读取训练,同时,下一次依旧从101开始保存,并非10001。所以检查点信息保存在了检查点文件中,并非记录在了文件名里,所以即便如此不正常的操作,也不会破坏训练。

3.restore

status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))

装载检查点

如果地址为空,则不装载,保存时从第一个检查点自动保存。
如果地址不正确,无此目录,则报错中断(此处不是抛出异常,因为默认情况下,会接收latest_checkpoint传参,只会出现有效和None的情况,所以TF官方似乎就没再做特有的异常处理)
好玩的来了, 如果地址是另外一个文件夹下的检查点的地址,或者已经有100个检查点了,指定了第50个检查点,那么其实并不复杂,保存还是会在此检查点之上保存。如果装载的是第50个检查点,则从51继续保存。

你可能感兴趣的:(python,机器学习,深度学习)