Tensorflow 模型保存与恢复(2)使用SavedModel

使用SavedModel 保存和恢复模型

本篇介绍使用SavedModel进行模型的保存与恢复。

其他相关:
Tensorflow 模型保存与恢复(1)使用tf.train.Saver()
Tensorflow 模型保存与恢复(3)保存模型到单个文件中

SavedModel 是一种跨语言的序列化格式(protobuf),可以保存和加载模型变量、图和图的元数据,适用于将训练得到的模型保存用于生产环境中的预测过程。由于跨语言的特性,应用时,可以使用一种语言保存模型,如训练时使用Python代码保存模型;使用另一种语言恢复模型,如使用C++代码恢复模型,进行前向推理,提高效率。
·SavedModel·可以为保存的模型添加签名·signature·,用于保存指定输入输出的graph, 另外可以为模型中的输入输出tensor指定别名,这样子使用模型的时候就不必关心训练阶段模型的输入输出tensor具体的name是什么,讲模型训练和部署解耦,更加方便。

1.保存模型

使用SavedModel保存模型时使用tf.saved_model.builder.SavedModelBuilder来实现,实例化一个builder,传入模型保存的路径:

builder = tf.saved_model.builder.SavedModelBuilder(model_path)

确定输入输出dict,这一步可以看成是给模型中的输入输出tensor取一个别名,可以在预测的时候直接使用这个别名,而不必关心其在原本的训练代码中的名字:

inputs = {'input0': tf.saved_model.utils.build_tensor_info(x_input)}
outputs = {'output0': tf.saved_model.utils.build_tensor_info(conv)}

创建signature:

my_signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs)

添加图和变量:

builder.add_meta_graph_and_variables(sess, ['MODEL_TRAINING'], signature_def_map={'my_signature': my_signature})

最后保存模型:

builder.save()

这样就使用SavedModel完成了模型的保存,保存的结果是两个文件(夹),一个是.pb文件,是序列化的tensorflow::SavedModel,包含了一个或者多个graph及其元数据;另一个是Variables文件夹,里面保存了序列化的Variables。

2.恢复模型

恢复模型的时候,使用tf.saved_model.loader.load()方法,加载模型:

with tf.Session() as sess:
        # load model
        meta_graph_def = tf.saved_model.loader.load(sess, ['MODEL_TRAINING'], model_path)

然后获取signature:

signature = meta_graph_def.signature_def

根据signature及定义的输入输出名称获取对应tensor名:

in_tensor_name = signature['my_signature'].inputs['input0'].name
out_tensor_name = signature['my_signature'].outputs['output0'].name

然后可以使用get_tensor_by_name方法获取对应的tensor,进行预测。
完整的代码:github saved_model

3.simple_save

如果不想手动构建builder,tensorflow也提供了一个simple_save方法,可以十分简洁的完成模型保存:

 tf.saved_model.simple_save(sess, model_path, inputs={'input0': x_input}, outputs={'output0': conv})

load模型方法及完整代码:github simple_save

你可能感兴趣的:(TensorFlow)