前面两篇介绍了使用Saver 和SavedModel保存模型:
Tensorflow 模型保存与恢复(1)使用tf.train.Saver()
Tensorflow 模型保存与恢复(2)使用SavedModel
上面两种方法保存的模型的变量都是单独保存在一个文件中的,模型的图元数据则是保存在另一个文件中。有时候希望将模型的元数据和变量值保存到同一个文件中,即最后的模型只有一个.pb文件。本篇介绍如何将TensorFlow的模型保存到单个文件中,以及读取恢复模型。
基本思路主要是利用GraphDef
对象,使用convert_variables_to_constants
方法将变量转化为常量,然后序列化保存到磁盘。基本步骤如下:
首先获取GraphDef
对象:
input_graph_def = graph.as_graph_def()
设定需要导出的节点,可以将模型的输入输出节点导出,在恢复模型进行预测的时候可以据此获取输入输出的tensor:
output_node_name = 'x_input,conv'
使用convert_variables_to_constants
方法将GraphDef
对象中的变量转化为常量,并返回一个新的GraphDef
:
output_graph_def = tf.graph_util.convert_variables_to_constants(sess=sess,
input_graph_def=input_graph_def,
output_node_name=output_node_name.split(','))
使用tf.gfile保存序列化后的GraphDef
到磁盘:
output_graph_def_filename = './frozen_model.pb'
with tf.gfile.GFile(output_graph_def_filename, 'wb')as f:
f.write(output_graph_def.SerializeToString())
使用模型的时候,首先使用tf.gfile读取模型文件并解析:
frozen_model_name = './frozen_model.pb'
with tf.gfile.GFile(frozen_model_name, 'rb') as f:
restored_graph_def = tf.GraphDef()
restored_graph_def.ParseFromString(f.read())
然后使用import_graph_def
方法导入GraphDef
对象到默认的图中:
with tf.Graph().as_default() as graph:
# import_graph_def to import a serialized GraphDef and extract the tensor, op,
# then place them to the default graph
tf.import_graph_def(graph_def=restored_graph_def,
#input_map=None,
#return_elements=None,
name="" # the name position parameter can't be ignore
)
之后可以根据保存模型时候导出的节点名称获取对应的tensor:
input_tensor = graph.get_tensor_by_name('x_input:0')
conv_tensor = graph.get_tensor_by_name('conv:0')
这样就完成了保存模型到单个文件中,并从模型文件中恢复模型进行预测的过程。
如果是希望将训练过程中得到的checkpoint文件转为单个.pb文件,可以先使用tf.trian.Saver保存模型中介绍的方法restore模型,然后按照上面的步骤保存模型到单个文件中。
详细完整的代码见:github 保存模型到单个文件中