Tensorflow2.0 继续训练自己未训练完的模型(tf.train.Checkpoint())

在我们使用tensorflow做深度学习的时候,需要用大量的数据来训练模型。但正因为数据量大如果电脑的性能不是很好的话在训练模型的时候我们的电脑是没有剩余的内存供我们使用的,但模型训练又需要花费很多时间,如果我们需要用电脑做其他事情的话就必须停止训练模型,但停止以后再重新开始从头训练的话又会花费很多的时间,所以我们要在停止训练时保存的模型参数的那个阶段继续我们的训练。

模型保存

首先我们要知道要想继续我们的训练就必须保存好我们之前训练好的模型参数,这样我们的程序才能使用现有的参数继续来训练模型而不是再随机生成参数那种大范围的拟合。
这里我保存模型的方法是用的 tf.train.Checkpoint()这个函数。

checkpoint_path = './checkpoint/train'
ckpt = tf.train.Checkpoint(transformer=transformer,optimizer=optimizer)
# ckpt管理器
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=3)

tf.train.Checkpoint()这个函数里面有两个参数一个是你要训练的模型,第二个是你的优化器。tf.train.CheckpointManager()函数里面有三个参数第一个是tf.train.Checkpoint()设定好你要保存的参数,第二个参数是你保存的路径,第三个参数是你要保存的模型数量。

当我从头开始训练的时候,这是训练后保存的模型参数和训练的准确度,稍后我们会用这个参数模型来演示中断后再开始训练的结果。
Tensorflow2.0 继续训练自己未训练完的模型(tf.train.Checkpoint())_第1张图片

Tensorflow2.0 继续训练自己未训练完的模型(tf.train.Checkpoint())_第2张图片

重新加载模型

checkpoint = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path1))

这两个函数就是重新加载我们的模型,需要注意的是tf.train.Checkpoint()里面的两个参数名要和你保存模型时的名字一样才行。

我们用上次保存下来的模型接着训练结果如下:
Tensorflow2.0 继续训练自己未训练完的模型(tf.train.Checkpoint())_第3张图片
第一次从头开始训练模型时准确度为0.22,我们训练三个循环后保存模型时准确度是0.76。
当我们重新训练加载已有的模型时初次训练的准确度是0.78。正好接着上次保存后的准确度。

其实从头开始训练和接着上次的继续训练就只加了这两行代码。让模型迭代第一次时有初始的参数可以用。

checkpoint = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path1))

你可能感兴趣的:(tensorflow,深度学习)