tensorflow2.4 用checkpoint保存网络,读取后继续训练!

Tensorflow2学习记录

提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加
例如:第一章 Tensorflow2网络的保存与恢复后继续训练。


提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • Tensorflow2学习记录
  • 前言
  • 一、官方帮助在哪里?
  • 二、学习过程
    • 开始理解checkpoint
    • 首先是声明阶段
    • 其次是保存阶段
    • 调用阶段
    • 自己的尝试
  • 最终的实现
  • 总结


前言

最近计划深入学习一下机器学习,奈何笔记本不给力,训练起来超级慢。
又必须关机带着电脑走。晚上也不想把电脑设置为高性能不关机的模式,
那样太吵了。所以一直在找怎么将继续训练的代码写出来,这样我就
可以随时停止训练,然后背电脑回家或者可以用电脑做别的事情,没事了
就继续接着训练。这样虽然效率低一点,但这样才是常态啊。不是每个人
都有专门的机器学习服务器呀!

原来在matlab上用过机器学习,刚开始具体接触Python下的机器学习,
听说TensorFlow2比较好用,就从它入手。结果网络上铺天盖地的都是
TensorFlow1的帖子,好像找到点有关系的帖子,但是要订阅,每个月
几十上百块。都收钱了,博客还有啥意义?


提示:以下是本篇文章正文内容,下面案例可供参考

一、官方帮助在哪里?

最害怕看到最后还没有找到自己的答案。
所以下面直接给出我用的官方链接,您可以先看看,
如果看了之后不会,您老再向后看看我是怎么弄的。
官方帮助链接:https://tensorflow.google.cn/guide/checkpoint?hl=en
github实例链接:https://github.com/tensorflow/docs/blob/master/site/en/guide/checkpoint.ipynb

github上的链接可能不好下载,但偶尔还可以看。所以下面就主要根据github上的实例介绍恢复并重新训练的功能。

二、学习过程

首先,你应该准备好网络了。你可以随便找个网络比如网络上到处都有的MNIST手写数字识别的网络。
最开始我看的是简单粗暴TensofFlow2里面的一个例子。
链接地址:https://tf.wiki/zh_hans/basic/tools.html#tf-train-checkpoint

开始理解checkpoint

我们这里的目标是临时保存网络,等有空了再把网络恢复,然后重新接着训练,这里的接着是最重要的。

它给出的例子如下

import tensorflow as tf
import numpy as np
import argparse
from zh.model.mnist.mlp import MLP
from zh.model.utils import MNISTLoader

parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--mode', default='train', help='train or test')
parser.add_argument('--num_epochs', default=1)
parser.add_argument('--batch_size', default=50)
parser.add_argument('--learning_rate', default=0.001)
args = parser.parse_args()
data_loader = MNISTLoader()

def train():
    model = MLP()
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
    num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)
    checkpoint = tf.train.Checkpoint(myAwesomeModel=model)      # 实例化Checkpoint,设置保存对象为model
    for batch_index in range(1, num_batches+1):                 
        X, y = data_loader.get_batch(args.batch_size)
        with tf.GradientTape() as tape:
            y_pred = model(X)
            loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
            loss = tf.reduce_mean(loss)
            print("batch %d: loss %f" % (batch_index, loss.numpy()))
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
        if batch_index % 100 == 0:                              # 每隔100个Batch保存一次
            path = checkpoint.save('./save/model.ckpt')         # 保存模型参数到文件
            print("model saved to %s" % path)


def test():
    model_to_be_restored = MLP()
    # 实例化Checkpoint,设置恢复对象为新建立的模型model_to_be_restored
    checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)      
    checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数
    y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)
    print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))


if __name__ == '__main__':
    if args.mode == 'train':
        train()
    if args.mode == 'test':
        test()

通过上面的代码,我知道在TF2里面有个叫checkpoint的东西,可以临时保存神经网络。
可以后期把前面保存的神经网络恢复后继续使用。
仔细研究后发现主要是这几句话,大家注意看代码里面的注释
具体流程是:

  • 声明阶段
    1. 先定义网路和优化器
    2. 定义一个checkpoint,实例化要保存啥内容。
  • 保存阶段
    1. 在保存的时候,一般是一个或几个batch训练完成之后
  • 调用阶段
    1. 在需要的时候调用之前训练的网络,继续训练。

首先是声明阶段

我习惯用声明,其实应该是实例化,以后改。

  • 先声明一个网络model,MLP就是一个多层感知器网络。
  • 又声明一个优化器optimizer,如果要保存优化器,就必须在声明阶段就定义。
  • 最后声明一个checkpoint。
    这个就很重要了,括号里面的东西我还没有找到明确的解释,只知道是一对对的参数。
    比如下面代码里面的(myAwesomeModel=model),
    表示要保存的是model,就是那个多层感知机。
    而myAwesomeModel就是保存的名称,
    恢复时也要用myAwesomeModel这个名字。
# 代码里面的注释完全看不清啊。
model = MLP() 
optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
# 实例化Checkpoint,设置保存对象为model
checkpoint = tf.train.Checkpoint(myAwesomeModel=model)     

注意: 定义完成以后我们已经做好了一切准备,只需要在保存的时候直接调用就好。

其次是保存阶段

保存阶段的代码如下

  • 先定一个保存时间间隔,毕竟每次训练完一个batch就保存一下似乎没有任何意义。
  • 这里的时间间隔是100个batch
  • 调用checkpoint.save开始保存,括号里面指定了保存的目录和文件名前缀。
    没错,’./save/model.ckpt’ 这个字符串里面 ./save/ 是指当前目录下的一个叫save的目录
    model.ckpt 是文件名的前缀,每调用一次保存就会在后面增加编号之类的后缀。
    这里保存了一个index文件和一个data文件。
  • checkpoint.save是有返回值得,返回值是保存的目录。
    所以,最后有一句打印保存的目录信息。
		# 每隔100个Batch保存一次
        if batch_index % 100 == 0:  
        	# 保存模型参数到文件                            
            path = checkpoint.save('./save/model.ckpt')         
            print("model saved to %s" % path)

调用阶段

保存完成后就可以调用了,在上面的例子中直接在测试阶段恢复了保存的网络。
实际意义不大,但借鉴意义很强。主要是函数定义的网络的定义域在函数里面,
其他函数看不到,所以需要恢复网络。

  • 第一句同样是定义网络,这里给了个不同的名字model_to_be_restored ,
    用于和我们保存的model区分
  • 第二句实例化一个checkpoint,这次的实例化必须和声明阶段里面的实例化完全相同
    这个相同是指实例化里面你指定的保存和恢复项目要相同。声明中我们定义保存
    神经网络,所以这里声明时也是只有神经网络。如果声明阶段你定义保存了神经网络
    和优化器,这里的恢复阶段也需要定义神经网络和优化器。
    (myAwesomeModel=model_to_be_restored) 是指把保存的myAwesomeModel
    网络的各种参数恢复给model_to_be_restored
  • 第三句就是利用restore恢复数据了,里面的**(tf.train.latest_checkpoint(’./save’))**表示
    从目录 ./save 目录里面恢复最后一次保存的网络。
    model_to_be_restored = MLP()
    # 实例化Checkpoint,设置恢复对象为新建立的模型model_to_be_restored
    checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)
    checkpoint.restore(tf.train.latest_checkpoint('./save'))

自己的尝试

我也写了个网络,由于我不需要在测试阶段调用,所以把恢复网络直接写在保存之前。
顺序是

  • 声明阶段
  • 恢复阶段
  • 保存阶段

调整后的代码如下

  • 实例化一个网络
  • 实例化一个checkpoint,指定要保存和恢复的内容,这里只保存和恢复网络。
  • 定一个检查器
  • 如果检查器为真,就是如果有保存网络的目录并且目录里有保存的结果
  • 就恢复最后一次保存的网络,因为普通情况下,最后一次保存的网络loss最小、精度最高。
  • 最后一行是在合适的时候保存网络。
# 声明网络和优化器
    model = MLP()
    # 实例化Checkpoint,设置恢复对象为新建立的模型model_to_be_restored
    checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
    # 从模型恢复参数
    cpkt = tf.train.get_checkpoint_state(savedir)
    if cpkt and cpkt.model_checkpoint_path:
    	checkpoint.restore(tf.train.latest_checkpoint('./save'))
    .
    . # 中间的其他代码
    .
	# 保存模型
	checkpoint.save('./save/model.ckpt') 

上面的代码我觉得很好啊,声明、恢复加保存,完美的循环。
可是恢复网络缺失做到了,但是每次还是从头开始训练,完全没有继续训练的意思。
有的帖子说是每次恢复之前重新指定了学习率 ,我也该过学习率,但根本没效果。

最终的实现

尝试了很多办法,最终还是选择了把官方的办法一个个的尝试,终于找到一个靠谱的方法。

  • 不是直接利用checkpoint,而是利用checkpoint manager来保存和恢复
  • 实例化checkpoint的参数也不一样了,不太清楚tf.Variable(1) 的作用。
  • 多实例化了一个manager 调用了一个 CheckpointManager 指定保存 checkpoint,
    保存目录,最大保存文件数量。
  • 恢复网络也变了,用 manager.latest_checkpoint 表示最后一个保存的网络。
  • 最后的保存也用了manager.save() 方法保存。

用这个套路保存的网络可以正常恢复后接着训练。大家看看有没有异常情况。

	# 初始化网络
    model = CNNmodel()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                     optimizer=optimizer,
                                     lidarmodel=model)
    manager = tf.train.CheckpointManager(checkpoint, './save', max_to_keep=3)
    # 恢复网络
    # 检查是否存在保存目录,保存目录里面是否有信息,装载最后一次的训练结果。继续训练
    cpkt = tf.train.get_checkpoint_state(savedir)
    if cpkt and cpkt.model_checkpoint_path:
        checkpoint.restore(manager.latest_checkpoint)
        print('Successfully loaded:', cpkt.model_checkpoint_path)
    else:
        print('Could not find old network weights!')
        
	#
	# 这里是其他代码
	#
	
	# 保存网络
        checkpoint.step.assign_add(1)
        print("batch %d/%d: loss %f"
              % (batch_index, num_epochs, loss.numpy()))
        # 开始保存网络
        if int(checkpoint.step % 10) == 0:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(int(checkpoint.step),
                                                            save_path))

总结

以上就是我先保存网络,恢复后继续训练的例子。
Python版本3.8
tensorflow版本2.4

你可能感兴趣的:(机器学习,机器学习)