Tensorflow 模型保存与恢复(3)保存模型到单个文件中

保存模型到单个.pb文件中

前面两篇介绍了使用Saver 和SavedModel保存模型:
Tensorflow 模型保存与恢复(1)使用tf.train.Saver()
Tensorflow 模型保存与恢复(2)使用SavedModel

1.保存模型

上面两种方法保存的模型的变量都是单独保存在一个文件中的,模型的图元数据则是保存在另一个文件中。有时候希望将模型的元数据和变量值保存到同一个文件中,即最后的模型只有一个.pb文件。本篇介绍如何将TensorFlow的模型保存到单个文件中,以及读取恢复模型。
基本思路主要是利用GraphDef对象,使用convert_variables_to_constants方法将变量转化为常量,然后序列化保存到磁盘。基本步骤如下:
首先获取GraphDef对象:

input_graph_def = graph.as_graph_def()

设定需要导出的节点,可以将模型的输入输出节点导出,在恢复模型进行预测的时候可以据此获取输入输出的tensor:

 output_node_name = 'x_input,conv'

使用convert_variables_to_constants方法将GraphDef对象中的变量转化为常量,并返回一个新的GraphDef:

output_graph_def = tf.graph_util.convert_variables_to_constants(sess=sess,
                                                                    input_graph_def=input_graph_def,
                                                                    output_node_name=output_node_name.split(','))

使用tf.gfile保存序列化后的GraphDef到磁盘:

output_graph_def_filename = './frozen_model.pb'
with tf.gfile.GFile(output_graph_def_filename, 'wb')as f:
        f.write(output_graph_def.SerializeToString())

2.恢复模型

使用模型的时候,首先使用tf.gfile读取模型文件并解析:

frozen_model_name = './frozen_model.pb'
with tf.gfile.GFile(frozen_model_name, 'rb') as f:
        restored_graph_def = tf.GraphDef()
        restored_graph_def.ParseFromString(f.read())

然后使用import_graph_def方法导入GraphDef对象到默认的图中:

with tf.Graph().as_default() as graph:
        # import_graph_def to import a serialized GraphDef and extract the tensor, op,
        # then place them to the default graph
        tf.import_graph_def(graph_def=restored_graph_def,
                                 #input_map=None,
                                 #return_elements=None,
                                 name=""  # the name position parameter can't be ignore
                                 )

之后可以根据保存模型时候导出的节点名称获取对应的tensor:

input_tensor = graph.get_tensor_by_name('x_input:0')
conv_tensor = graph.get_tensor_by_name('conv:0')

这样就完成了保存模型到单个文件中,并从模型文件中恢复模型进行预测的过程。
如果是希望将训练过程中得到的checkpoint文件转为单个.pb文件,可以先使用tf.trian.Saver保存模型中介绍的方法restore模型,然后按照上面的步骤保存模型到单个文件中。

详细完整的代码见:github 保存模型到单个文件中

你可能感兴趣的:(TensorFlow)