TensorFlow supports several model formats, and different scenarios call for different ones.
Format | Description |
---|---|
Checkpoint | Stores the model's weights; mainly used to back up parameters during training and to warm-start training. |
GraphDef | Stores the model's Graph without the weights; together with a Checkpoint it contains everything needed to serve the model. |
SavedModel | Model files exported with the saved_model API, containing both the Graph and the weights, ready to be served directly; this is the recommended format for TensorFlow and Keras models. |
FrozenGraph | Produced by freeze_graph.py, which merges and optimizes a Checkpoint and a GraphDef; can be deployed directly to mobile devices such as Android and iOS. |
TFLite | A FlatBuffers-based optimized model format that can also be deployed directly to Android, iOS and other mobile devices; its API differs somewhat from FrozenGraph's (see the sketch after this table). |
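To make the TFLite row concrete, here is a minimal conversion sketch. It assumes a frozen graph at ./model.pb with the input/output node names used later in this article, and uses tf.lite.TFLiteConverter, which newer TF 1.x releases provide (earlier releases shipped it under tf.contrib.lite):

```python
import tensorflow as tf

# Assumption: a frozen graph at ./model.pb with the node names
# "inference_input" and "inference_output" used in this article.
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file="./model.pb",
    input_arrays=["inference_input"],
    output_arrays=["inference_output"])
tflite_model = converter.convert()

# The resulting FlatBuffers file can be shipped inside Android/iOS apps.
with open("./model.tflite", "wb") as f:
    f.write(tflite_model)
```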
In the previous post, "TensorFlow model saving and loading (part 1)", we covered saving and loading models in Checkpoint format, which produces a .data file holding the weights and a .meta file holding the meta graph.
Now we want to serve the model in production, so we would like to package the graph and the weights into a single artifact that is easy to store, upgrade, and version.
Exporting in SavedModel format lets us serve the model directly with the generic TensorFlow Serving service: once exported, the model can go live without changing any code. For different models, all you need to do at export time is specify the input and output signatures:
```python
import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils

# `sess`, `graph` and `export_version_path` come from the training script.
# Describe the model's inputs and outputs by tensor name.
model_signature = signature_def_utils.build_signature_def(
    inputs={
        # Equivalent lookup: graph.get_operation_by_name("operation").outputs[0]
        "input": utils.build_tensor_info(graph.get_tensor_by_name("inference_input:0"))
    },
    outputs={
        "output": utils.build_tensor_info(graph.get_tensor_by_name("inference_output:0")),
        # "prediction": utils.build_tensor_info(graph.get_tensor_by_name("inference_op:0")),
    },
    method_name=signature_constants.PREDICT_METHOD_NAME)

builder = saved_model_builder.SavedModelBuilder(export_version_path)
# Add the current meta graph to the SavedModel and save the variables.
builder.add_meta_graph_and_variables(
    sess,
    [tag_constants.SERVING],
    # Set to True to clear device placement info on the default graph.
    clear_devices=True,
    signature_def_map={
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            model_signature,
    },
    # tf.group() creates an op that groups multiple operations: when it
    # finishes, all ops in its input have finished, and it has no output.
    # legacy_init_op runs after the restore op when the model is loaded.
    legacy_init_op=tf.group(
        # Do not use the global/local variable initializers here, or the
        # exported model would re-initialize state at load time and infer
        # from the file-reading input pipeline instead of the placeholder.
        # tf.global_variables_initializer(),
        # tf.local_variables_initializer(),
        tf.tables_initializer(),
        name="legacy_init_op"))
builder.save()
```
After exporting the SavedModel with the TensorFlow API, you can inspect the model's directory structure:
```
.
└── 1
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00002
        ├── variables.data-00001-of-00002
        └── variables.index
```
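To sanity-check the export, you can load the SavedModel back into a fresh session and run the signature's tensors directly. A minimal sketch; the export path and the dummy input shape are assumptions to adapt to your model:

```python
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

export_version_path = "./1"  # assumed: one exported version directory

with tf.Session(graph=tf.Graph()) as sess:
    # Load the graph and restore the variables tagged for serving.
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_version_path)
    input_tensor = sess.graph.get_tensor_by_name("inference_input:0")
    output_tensor = sess.graph.get_tensor_by_name("inference_output:0")
    # Dummy input; replace with data matching your model's input shape.
    print(sess.run(output_tensor, feed_dict={input_tensor: [[1.0, 2.0]]}))
```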
For the FrozenGraph format, the model is persisted by folding the variables into constants with graph_util.convert_variables_to_constants:
```python
# coding=UTF-8
import argparse

import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(model_folder, output_graph):
    # Find the latest usable checkpoint in the directory.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Before exporting the graph we need to specify the output nodes.
    # This is how TF decides which part of the graph it has to keep
    # and which part it can drop.
    output_node_names = "inference_input,inference_output"  # NOTE: change here

    # Clear devices so TensorFlow can control where operations are loaded.
    clear_devices = True

    # Import the meta graph and retrieve a Saver.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=clear_devices)

    # Retrieve the protobuf graph definition.
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # Start a session and restore the graph weights.
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # Use a built-in TF helper to export variables to constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # The output node names select the useful nodes
        )

        # Finally, serialize and dump the output graph to the filesystem.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

        for op in graph.get_operations():
            print(op.name, op.values())
        print("output_graph:", output_graph)
        print("all done")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="TensorFlow graph freezer: converts .ckpt to .pb")
    parser.add_argument("model_folder", type=str, nargs="?",
                        default="./checkpoint/", help="input ckpt model dir")
    parser.add_argument("model_name", type=str, nargs="?",
                        default="./model.pb", help="output model name")
    args = parser.parse_args()
    print(args, "\n")

    freeze_graph(args.model_folder, args.model_name)
```
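The frozen .pb can then be loaded for inference without any checkpoint files. A minimal sketch, assuming the tensor names above; the dummy input shape is a placeholder to adapt:

```python
import tensorflow as tf

def load_frozen_graph(pb_path):
    # Read the serialized GraphDef and import it into a fresh graph.
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph

graph = load_frozen_graph("./model.pb")
input_tensor = graph.get_tensor_by_name("inference_input:0")
output_tensor = graph.get_tensor_by_name("inference_output:0")
with tf.Session(graph=graph) as sess:
    # Dummy input; replace with data matching your model's input shape.
    print(sess.run(output_tensor, feed_dict={input_tensor: [[1.0, 2.0]]}))
```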
References
- freeze_graph.py source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py
- Serving with the native gRPC API or RESTful API: https://github.com/tensorflow/serving
- Serving over a simple HTTP interface: https://github.com/tobegit3hub/simple_tensorflow_serving
- Deploying to Android: https://medium.com/@tobe_ml/all-tensorflow-models-can-be-embedded-into-mobile-devices-1932e80579e5
- Deploying to iOS: https://zhuanlan.zhihu.com/p/33715219
- https://www.tensorflow.org/deploy/
- https://www.tensorflow.org/mobile/prepare_models