(以下 tensorflow 简写为 tf)
通过将 tf 训练的模型保存到文件中,方便模型在预测时选择模型进行预测。
tf 中通过使用 tf.train.Saver 类,调用其 save 方法将模型进行保存
实现方式如下:
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 14 14:22:02 2019
@author: JustMo
"""
import tensorflow as tf
import os
with tf.variable_scope('test'):
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(1.0))
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2.0))
result = v1 + v2
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess, '/path/to/model/model.ckpt')
因为 tf 会将计算图的结构和图上的参数取值分开保存,所以虽然在保存时,只指定了一个文件;但是调用这个 API 之后会生成4个文件,tf 的持久化就是通过这4个文件完成的:
维护了一个由 tf.train.Saver 类持久化的所有 tf 模型文件的文件名。当某个保存的 tf 模型文件被删除时,这个模型对应的文件名也会从 checkpoint 中移除。
类型定义:
#保存最新的 tf 模型文件的文件名
model_checkpoint_path: '/path/to/model/model.ckpt'
#列出了当前还没有被删除的所有 tf 模型文件的文件名
all_model_checkpoint_path: '/path/to/model/model.ckpt'
tf 通过对元图(MetaGraph)来记录计算图节点中的信息、运行计算图中节点所需要的元数据。
MetaGraphDef主要记录6类信息,属性如下:
meta_info_def---记录了 tf 计算图中的元数据以及 tf 程序中所有使用到的运算方法信息
graph_def---主要记录 tf 计算图上的节点信息,即运算的连接结构
saver_def---记录了模型持久化时需要用到的一些参数,如保存到文件的文件名、保存操作和加载操作的名称以及、保存频率、清理历史记录...
collection_def---是一个从集合名称到集合内容的映射
signature_def---应该是导出模型后,.index文件的部分,建立张量名到张量的映射。
asset_file_def---单个文件或具有相同名称的一组分片文件的资源文件def,应该是保存权重的.data文件,根据AssetFileDef属性在.data中寻找对应的权重参数。
因为 tf 生成的 model.ckpt.meta 是个二进制文件,无法直接查看。不过 tf 提供了 export_meta_graph 函数以json的格式导出 meta文件。调用方式如下:
saver.export_meta_graph('/path/to/model/model_test.ckpt.meta.json', as_text=True)
这时就可以方便的查看模型的结构了。
这两个模型文件,主要是针对于持久化 tf 中变量的取值。通过 tf.train.saver 得到 model.ckpt.index 以及 model.ckpt.data-xxxxx-of-xxxxx 文件保存了所有变量的取值。
model.ckpt.data 是通过 SSTable 格式:(key, value)列表格式存储的。
tf 中提供了 tf.train.NewCheckpointRrader 来查看保存的变量信息。调用方式如下:
#通过NewCheckpointReader读取checkpoint中所有保存的变量
reader = tf.train.NewCheckpointReader('/path/to/model/model_test.ckpt')
#获取一个从变量名到变量维度的字典类型的所有变量列表
golbal_varibles = reader.get_variable_to_shape_map()
for variable_name in golbal_varibles :
print(variable_name , golbal_varibles[variable_name])
#获取变量空间下名称为v1的变量的取值
print('value for v1 is',reader.get_tensor('test/v1'))
得到结果如下:
test/v1 [1]
test/v2 [1]
value for v1 is [ 1.]
加载模型的代码和保存的代码差不多。在加载模型的程序中,也需要先定义 tf 计算图上的所有运算,并声明 tf.train.Saver 类,然后通过 saver 的 restore 来进行加载。
不同点在于,在加载模块中,没有运行变量初始化的过程,而是将变量的值通过已经保存的模型加载进来。
如果不希望重复定义图上的运算,也可以使用 import_meta_graph 直接通过 meta 加载持久化的图。
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 14 16:46:54 2019
@author: JustMo
"""
import tensorflow as tf
import os
with tf.variable_scope('test',reuse=False):
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(1.0))
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2.0))
result = v1 + v2
saver = tf.train.Saver()
#不用重复定义图上运算,直接从持久化中获取模型
saver2 = tf.train.import_meta_graph('/path/to/model/model_test.ckpt.meta')
with tf.Session() as sess:
#通过meta保存的信息,来计算result
saver.restore(sess, '/path/to/model/model_test.ckpt')
print(sess.run(result))
saver2.restore(sess, '/path/to/model/model_test.ckpt')
#通过张量名称获取张量值
print(sess.run(tf.get_default_graph().get_tensor_by_name("test/add:0")))
前面的方式都是从持久化模型中完全加载模型,但是在实际的运用中,很有可能只使用部分变量,如迁移学习中,这个之后之前的方式就不适用了。
为了保存或者加载部分变量,在声明 tf.train.Sever 类时,可以提供一个列表来指定需要保存或者加载的变量:
saver = tf.train.Saver([v1])
使用 tf.train.Saver 会保存 tf 程序所需要的全部信息,但是在 测试、离线预测、迁移学习时,只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可。不需要变量的初始化、模型保存等辅助结点的信息。且模型保存会将变量取值以及计算图结构分成两个文件保存,这在取值的时候也不方便,所以有 conver_variables_to_constants 函数将计算图中的变量以及取值通过常量的方式来保存,这样整个 tf 计算图就可以统一存放在一个文件中。
保存:
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 14 17:48:18 2019
@author: JustMo
"""
import tensorflow as tf
from tensorflow.python.framework import graph_util
with tf.variable_scope('test'):
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(1.0))
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2.0))
result = v1 + v2
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
#导出当前计算图的GraphDef部分,这部分保存的是计算的连接,只需要这部分就可以计算从输入到输出
graph_def = tf.get_default_graph().as_graph_def()
#将计算图中的变量以及变量的取值转化为常量,并提供需要的计算节点名,可以去掉不需要的节点,因为其他的和计算无关的节点没有必要导出并保存了。
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['test/add'])
#保存模型
with tf.gfile.GFile('/path/to/model/combined_model_test.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString())
加载:
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 14 17:59:40 2019
@author: JustMo
"""
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
#需要加载的模型名称
model_filename = '/path/to/model/combined_model_test.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
#进行解析
graph_def.ParseFromString(f.read())
#通过加载模型中的图到当前的图中。return_elements给出了返回张量的名称,这个时候要使用张量名,不是节点名
result = tf.import_graph_def(graph_def, return_elements=['test/add:0'])
print(sess.run(result))