欢迎关注简书:公输睚信
上一篇文章 TensorFlow 训练 CNN 分类器 中说明了训练简单 CNN 模型的整个过程,并在训练结束后使用 .save
函数来保存训练的结果,其后通过使用 tf.train.import_meta_graph
和 .restore
函数来导入模型进行推断。本文承接上文,对模型保存与恢复做一个总结。
总的来说,模型在保存和恢复时最重要的是留下数据接口,方便使用时传入数据和获取结果。TensorFlow 中常用的模型保存格式为 .ckpt 和 .pb,下面分别进行详细说明。
.ckpt 格式保存与恢复都很简单,具体可参考 TensorFlow 训练 CNN 分类器。
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs') <-- 入口
···
prediction = tf.nn.softmax(logits, name='prediction') <-- 出口(具体输出依情况而定,下同)
···
saver = tf.train.Saver()
···
with tf.Session() as sess:
··· <-- 训练过程
saver.save(sess, './xxx/xxx.ckpt') <-- 模型保存
如上述代码所示,假设你定义了一个 TensorFlow 模型,数据入口由占位符 inputs
给定,结果出口由张量 prediction
给定。通过语句 saver = tf.train.Saver()
定义了模型保存的一个实例对象 saver
,当模型训练结束之后只需要简单的一条语句:
saver.save(sess, path_to_model.ckpt)
就把训练结果保存到了指定的路径。
以上代码之所以把变量 inputs
和 prediction
单独列出,一方面是因为它们是模型 Graph 的起点和终点(戏称为数据入口、出口),另一方面的原因是它们被特别的指定了名称,因而在模型恢复时可以通过它们的名称而得到 Graph 中对应的节点。
当你需要导入模型进行推断时,只需要通过张量名获取数据入口和出口,然后传入数据即可:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')
saver.restore(sess, './xxx/xxx.ckpt')
inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
prediction = tf.get_default_graph().get_tensor_by_name('prediction:0')
pred = sess.run(prediction, feed_dict={inputs: xxx}
保存为 .ckpt 模型的一个好处是,当需要继续训练时,只需要将训练过的模型结果导入,然后在这个基础上再继续训练。而下面的 .pb 格式则不能继续训练,因为这种格式保存的模型参数都已经转化为了常量(而不再是变量)。
.pb 格式模型保存与恢复相比于前面的 .ckpt 格式而言要稍微麻烦一点,但使用更灵活,特别是模型恢复,因为它可以脱离会话(Session)而存在,便于部署。
与 .ckpt 格式模型保存类似,首先定义数据入口、出口:
from tensorflow.python.framework import graph_util
···
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs')
···
prediction = tf.nn.softmax(logits, name='prediction')
···
with tf.Session() as sess:
··· <-- 训练过程
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['prediction'] <-- 参数:output_node_names,输出节点名
)
with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
serialized_graph = output_graph_def.SerializeToString()
fid.write(serialized_graph)
然后通过函数 graph_util.convert_variables_to_constants
将模型固话,使得所有变量转化为常量,之后写入到指定的路径完成模型保存过程。
.pb 格式模型恢复自由度较大,不需要在会话里进行操作,可以独立存在:
import os
def load_model(path_to_model.pb):
if not os.path.exists(path_to_model.pb):
raise ValueError("'path_to_model.pb' is not exist.")
model_graph = tf.Graph()
with model_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(path_to_model.pb, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
return model_graph
模型导入之后,便可以获取数据入口和出口,然后进行推断:
model_graph = load_model('./xxx/xxx.pb')
inputs = model_graph.get_tensor_by_name('inputs:0')
prediction = model_graph.get_tensor_by_name('prediction:0')
with model_graph.as_default():
with tf.Session(graph=model_graph) as sess:
···
pred = sess.run(prediction, feed_dict={inputs: xxx}
一般情况下,为了便于从断点之处继续训练,模型通常保存为 .ckpt 格式,而一旦对训练结果很满意之后则可能需要将 .ckpt 格式转化为 .pb 格式。转化方法很简单,只需要综合前面的一、二两步即可:
from tensorflow.python.framework import graph_util
with tf.Session() as sess:
# Load .ckpt file
saver = tf.import_meta_graph('./xxx/xxx.ckpt.meta')
saver.restore(sess, './xxx/xxx.ckpt')
# Save as .pb file
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['prediction'] <-- 输出节点名,以实际情况为准
)
with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
serialized_graph = output_graph_def.SerializeToString()
fid.write(serialized_graph)
预告:下一篇文章将简单介绍 tensorflow.contrib.slim
的应用,敬请关注!