TensorFlow: Saving and Loading Models (Part 2)

TensorFlow has several model formats, each suited to a different scenario.

Format       Description
Checkpoint   Saves the model's weights. Mainly used to back up parameters during training and to warm-start training.
GraphDef     Saves the model's graph without the weights; combined with a checkpoint, it contains everything needed to put a model online.
SavedModel   Exported with the saved_model API. Contains both the graph and the weights and can be served directly; the recommended format for TensorFlow and Keras models.
FrozenGraph  Produced by freeze_graph.py, which merges and optimizes a checkpoint and a GraphDef. Can be deployed directly to mobile devices such as Android and iOS.
TFLite       Optimizes the model using FlatBuffers. Can also be deployed directly to Android and iOS, with an API that differs somewhat from FrozenGraph's.

In the previous post, TensorFlow: Saving and Loading Models (Part 1), we covered saving and loading models in Checkpoint form, which produced a .data file holding the weights and a .meta file holding the meta graph.
Now we want to serve the model in production, so we would like to package the graph and the weights into a single artifact that is easy to store, upgrade, and version.

Serving with TensorFlow Serving

Exporting the model in the SavedModel format lets you use the generic TensorFlow Serving service directly: once exported, the model can go online without any code changes. For different models, you only need to specify the input and output signatures at export time.

    # Assumes an active `sess`, its `graph`, and an export path such as
    # export_version_path = "./export/1" (TensorFlow Serving expects a numeric version subdirectory).
    # Required imports (TF 1.x):
    #   from tensorflow.python.saved_model import builder as saved_model_builder
    #   from tensorflow.python.saved_model import signature_constants, signature_def_utils, tag_constants, utils

    # Declare the inference inputs and outputs of the graph as a signature,
    # so that TensorFlow Serving knows which tensors to feed and fetch.
    model_signature = signature_def_utils.build_signature_def(
        inputs={
            # Equivalent: graph.get_operation_by_name("inference_input").outputs[0]
            "input": utils.build_tensor_info(graph.get_tensor_by_name("inference_input:0"))
        },
        outputs={
            "output": utils.build_tensor_info(graph.get_tensor_by_name("inference_output:0")),
            # Additional outputs can be exposed the same way, e.g.:
            # "prediction": utils.build_tensor_info(graph.get_tensor_by_name("inference_op:0")),
        },
        method_name=signature_constants.PREDICT_METHOD_NAME)

    builder = saved_model_builder.SavedModelBuilder(export_version_path)
    # Add the current meta graph to the SavedModel and save the variables.
    builder.add_meta_graph_and_variables(
        sess,
        [tag_constants.SERVING],
        clear_devices=True,  # Clear device placement info from the default graph.
        signature_def_map={
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                model_signature,
        },
        # tf.group() creates an op that groups multiple operations: when it
        # finishes, all ops in its inputs have finished, and it has no output.
        # legacy_init_op runs after the restore op when the model is loaded.
        legacy_init_op=tf.group(
            # Do not add the global/local variable initializers here; otherwise
            # the exported model would serve from the training file-reading
            # pipeline instead of the placeholder.
            # tf.global_variables_initializer(),
            # tf.local_variables_initializer(),
            tf.tables_initializer(),
            name="legacy_init_op"))

    builder.save()

After exporting the SavedModel with the TensorFlow API, you can inspect the model's directory structure:
.
└── 1
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00002
        ├── variables.data-00001-of-00002
        └── variables.index
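
Before handing the directory to TensorFlow Serving, it is worth loading it back locally as a sanity check. Below is a minimal sketch (TF 1.x API), assuming the export path ./export/1 and the "input"/"output" signature keys used above; the dummy feed value is a placeholder for real data:

# coding=UTF-8
import tensorflow as tf

export_dir = "./export/1"  # assumption: the version directory exported above

with tf.Session(graph=tf.Graph()) as sess:
    # Load the graph and restore the variables saved under the SERVING tag
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)

    # Look up the tensors registered in the default serving signature
    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    signature = meta_graph_def.signature_def[signature_key]
    input_tensor = sess.graph.get_tensor_by_name(signature.inputs["input"].name)
    output_tensor = sess.graph.get_tensor_by_name(signature.outputs["output"].name)

    # Feed a dummy batch; replace with data matching your model's input shape
    print(sess.run(output_tensor, feed_dict={input_tensor: [[1.0, 2.0]]}))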

Serving the model from other languages

Persist the model with graph_util.convert_variables_to_constants

# coding=UTF-8
import tensorflow as tf
import os.path
import argparse
from tensorflow.python.framework import graph_util


def freeze_graph(model_folder, output_graph):
    # Check that the directory contains a usable checkpoint
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Before exporting the graph, we need to specify the output nodes.
    # This is how TF decides which parts of the graph it has to keep and which it can drop.
    output_node_names = "inference_input,inference_output"  # NOTE: change to your own node names

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # The output node names are used to select the usefull nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

        for op in graph.get_operations():
            print(op.name, op.values())

        print("output_graph:", output_graph)
        print("all done")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="TensorFlow graph freezer: converts a .ckpt into a .pb file",
                                     prefix_chars='-')
    # nargs="?" makes the defaults effective when the positional args are omitted
    parser.add_argument("model_folder", type=str, nargs="?", default="./checkpoint/",
                        help="input ckpt model dir")
    parser.add_argument("model_name", type=str, nargs="?", default="./model.pb",
                        help="output model file")

    args = parser.parse_args()
    print(args, "\n")

    freeze_graph(args.model_folder, args.model_name)
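
Once model.pb is written, any process can run inference from it without the training code or checkpoints; other languages load it through the same GraphDef protobuf. Below is a minimal Python sketch, assuming the inference_input/inference_output node names used above:

# coding=UTF-8
import tensorflow as tf

# Parse the frozen GraphDef from disk
with tf.gfile.GFile("./model.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import it into a fresh graph; name="" keeps the original node names
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

input_tensor = graph.get_tensor_by_name("inference_input:0")
output_tensor = graph.get_tensor_by_name("inference_output:0")

with tf.Session(graph=graph) as sess:
    # The variables are now constants, so no restore step is needed
    print(sess.run(output_tensor, feed_dict={input_tensor: [[1.0, 2.0]]}))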

Reference:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py

Putting the model online

Deploying an online service

Use the native gRPC API or the RESTful API: https://github.com/tensorflow/serving.
For an HTTP interface, see https://github.com/tobegit3hub/simple_tensorflow_serving.
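
For a quick smoke test of the RESTful API, you can POST JSON directly to the predict endpoint. Below is a minimal sketch, assuming TensorFlow Serving was started with --rest_api_port=8501 and --model_name=default (both values are placeholders for your own setup):

# coding=UTF-8
import json
import requests

# TensorFlow Serving's REST predict endpoint: /v1/models/<name>:predict
url = "http://localhost:8501/v1/models/default:predict"
payload = {"instances": [[1.0, 2.0]]}  # shape must match the "input" signature

response = requests.post(url, data=json.dumps(payload))
print(response.json())  # e.g. {"predictions": [...]}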

Deploying to offline devices

For deployment to Android, see https://medium.com/@tobe_ml/all-tensorflow-models-can-be-embedded-into-mobile-devices-1932e80579e5 .

For deployment to iOS, see https://zhuanlan.zhihu.com/p/33715219 .
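
As the format table above notes, TFLite is the other common route to mobile. Below is a minimal conversion sketch, assuming a TF 1.x release that ships tf.lite.TFLiteConverter and reusing the frozen model.pb and node names from the freeze example:

# coding=UTF-8
import tensorflow as tf

# Convert the frozen graph into a FlatBuffers-based .tflite model
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file="./model.pb",
    input_arrays=["inference_input"],
    output_arrays=["inference_output"])
tflite_model = converter.convert()

with open("./model.tflite", "wb") as f:
    f.write(tflite_model)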

Reference:

https://www.tensorflow.org/deploy/
https://www.tensorflow.org/mobile/prepare_models
