TensorFlow 模型保存与复用的实例讲解

一、环境

Python 3.7.3 (Anaconda 3)

TensorFlow 1.14.0

二、方法

TensorFlow 模型保存与恢复的方法主要由 tf.train.Saver 类提供,同时也结合一些模型图加载等方法。

相关方法的官网说明:
https://www.tensorflow.org/guide/saved_model?hl=zh-cN
https://tensorflow.google.cn/api_docs/python/tf/train/import_meta_graph
https://tensorflow.google.cn/api_docs/python/tf/train/Saver

1、模型保存

该阶段一般被称为 train 阶段,主要包括:

  • 构建模型
  • 训练模型
  • 保存模型

其中保存模型主要通过 tf.train.Saver 类对象的 save 方法来完成,在指定的保存模型的目录下会生成四种类型的文件:

saved_models_directory:
    ******.meta
    ******.index
    ******.data-00000-of-00001
    checkpoint

其中,
(1)checkpoint 文件:记录了最近 5 次(创建 tf.train.Saver 类对象时,参数 max_to_keep 的默认值)训练所保存的模型文件的一个列表,该文件可以通过普通文本编辑器打开查看;
(2).meta 文件:保存了模型的图(网络结构);
(3).index 文件和 .data-00000-of-00001 文件:(暂未找到官方说明文档,待补充,现有很多资料都指出这两个文件保存训练好的模型的变量值(权重、偏置等),但都未提供官方具体说明,因此这里不在人云亦云,等找到官方说明文档确认后再更新)

需要说明的是:

在训练大量批量的模型时,如果创建 tf.train.Saver() 类对象的参数 max_to_keep 保持默认值 5,那么保存模型的目录下会产生一个 checkpoint 文件,同时后 3 种文件会在模型保存目录下生成 5 份,即最近保存的 5 次模型)

2、模型复用

一般称其为 inference 阶段,主要分为:
(1)构建图(简单理解为创建网络结构,具体有如下两种方式)

  • 手动重新构建图

  • 自动从保存文件中恢复图

(2)恢复变量(权重、偏置、超参数等)
(3)运行相关操作

三、实例讲解

官网参考:https://www.tensorflow.org/guide/saved_model?hl=zh_cn

1、模型保存
>>> import tensorflow as tf
# 创建变量并初始化幅值为 0
>>> v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
>>> v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
# 对两个变量的数值分别加、减 1
>>> inc_v1 = v1.assign(v1+1)
>>> inc_v2 = v2.assign(v2-1)
# 添加保存和复用变量的操作
>>> init_op = tf.global_variables_initializer()
# 创建 Saver 对象
>>> saver = tf.train.Saver()
# 创建会话
>>> with tf.Session() as sess:
# 初始化变量
...     sess.run(init_op)
# 执行操作
...     inc_v1.op.run() # 复制操作有 op
...     inc_v2.op.run()
# 保存变量到指定目录
...     save_path = saver.save(sess,'/saved_models_directory/model.ckpt')
...     print("Model saved in path: %s" % save_path)
... 
Model saved in path: /saved_models_directory/model.ckpt

然后会在指定目录下生成如下四个文件:
model.ckpt.meta
model.ckpt.index
model.ckpt.data-00000-of-00001
checkpoint

2、模型复用

方法一:手动重新构建图

# 重置图
>>> tf.reset_default_graph()
# 构建与训练阶段相同的模型及变量
>>> v1 = tf.get_variable("v1",shape=[3])
>>> v2 = tf.get_variable("v2",shape=[5])
# 创建 tf.train.Saver() 对象
>>> saver = tf.train.Saver()
>>> with tf.Session() as sess:
# 恢复模型的相关变量
...     saver.restore(sess, '/saved_models_directory/model.ckpt') # 注意第二参数传入的检查点文件没有后缀名
...     print("Model restored.")
# 执行相关操作
...     print("v1 : %s" % v1.eval())
...     print("v2 : %s" % v2.eval())
... 

Model restored.
v1 : [1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]

PS:手动重建图(网络结构),恢复保存的训练好的变量值,一般不会出错,但相对麻烦!

方法二:自动从保存文件中恢复图
直接从保存的文件中恢复训练好的模型的网络结构以及相关变量的数值,相对简便快捷,但是对于初学者来说,可能会遇到一些问题!

首先,可能会遇到的一个错误是忘记在恢复变量值之前先恢复 TensorFlow 的图,也就是训练好的模型的网络结构

>>> tf.reset_default_graph()
>>> with tf.Session() as sess:
...     saver.restore(sess,'/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     all_vars = tf.global_variables()
...     for v in all_vars:
...             print(v.name)
... 
Traceback (most recent call last):
  File "", line 2, in <module>
  File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 1286, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1098, in _run
    raise RuntimeError('The Session graph is empty.  Add operations to the '
RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().

错误提示:会话的图是空的,需要先创建图或者导入已经保存模型的图,因此,需要注意的是在恢复训练好的变量值之前,必须要先恢复图(网络结构)

其次,需要特别注意 TensorFlow 从保存文件中恢复出的模型的变量名与训练阶段模型中的变量名有区别,在训练阶段模型的变量名的后面添加了“:0”后缀,如:“v1”变为“v1:0”,因此初学者可能会遇到一些错误

下面实例可以观察到这种恢复变量的后缀变化

>>> tf.reset_default_graph()
>>> # latest_model_file = tf.train.latest_checkpoint("/saved_models_directory/")
>>> saver = tf.train.import_meta_graph("/saved_models_directory/model.ckpt.meta")
>>> with tf.Session() as sess:
...     saver.restore(sess,'/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     all_vars = tf.global_variables()
...     for v in all_vars:
...             print(v.name)
... 
Model restored.
v1:0
v2:0

可以发现从保存的模型中恢复的变量的名称后会多了一个“:0”,所以查看变量 v1 和 v2 时需要使用新的变量名“v1:0”和“v2:0”

因此,在恢复模型、恢复变量之后,就可以通过新的的变量名执行相关操作了,可以通过 v1.eval() 或 sess.run(v1) 等方法打印输出相关值查看

>>> tf.reset_default_graph()
>>> saver = tf.train.import_meta_graph("/saved_models_directory/model.ckpt.meta")
>>> with tf.Session() as sess:
...     saver.restore(sess,'/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     v1 = tf.get_default_graph().get_tensor_by_name('v1:0')
...     v2 = tf.get_default_graph().get_tensor_by_name('v2:0')
...     print("v1 : %s" % v1.eval())
...     print("v2 : %s" % sess.run(v2))
... 
Model restored.
v1 : [1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]

也可以通过如下方法执行操作

>>> tf.reset_default_graph()
>>> saver = tf.train.import_meta_graph("/saved_models_directory/model.ckpt.meta") # 注意 .meta 文件中保存模型的图结构
>>> with tf.Session() as sess:
...     saver.restore(sess,'/saved_models_directory/model.ckpt') # 注意第二参数传入的检查点文件没有后缀名
...     print("Model restored.")
...     print (sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
...     print (sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))
... 
Model restored.
[1. 1. 1.]
[-1. -1. -1. -1. -1.]

错误实例:没有使用模型复用后自动添加后缀的新变量名

>>> with tf.Session() as sess:
...     saver = tf.train.import_meta_graph("/saved_models_directory/model.ckpt.meta")
...     saver.restore(sess, '/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     print("v1 : %s" % v1.eval())
...     print("v2 : %s" % v2.eval())
... 
# 报错提示如下:
Traceback (most recent call last):
  File "", line 5, in <module>
NameError: name 'v1' is not defined

你可能感兴趣的:(TensorFlow基础)