TensorFlow wide and deep模型保存和推理

模型代码

https://github.com/NVIDIA/DeepLearningExamples.git

DeepLearningExamples/TensorFlow/Recommendation/WideAndDeep/

参考代码库中的说明下载数据集并进行训练,得到 ckpt。

保存saved model

(在task.py修改)

# Build the tf.Example parsing spec from every feature column the model uses
# (deep and wide columns together).
parse_example_spec = tf.feature_column.make_parse_example_spec(
    deep_columns + wide_columns
)
print("parse_example_spec:", parse_example_spec)

# Create a serving_input_receiver_fn from that spec, so the exported model
# takes serialized tf.Example protobuf strings and parses them into feature
# tensors.
serving_input_receiver_fn = (
    tf.estimator.export.build_parsing_serving_input_receiver_fn(
        parse_example_spec
    )
)
print("serving_input_receiver_fn:", serving_input_receiver_fn)

# Write the model in SavedModel format under the current directory.
estimator.export_saved_model(
    export_dir_base='./',
    serving_input_receiver_fn=serving_input_receiver_fn,
)

注意:创建 feature columns 时,需要把 tf.feature_column 中使用的 dtype 由 tf.int32 改成 tf.int64(tf.Example 的整数特征只支持 int64),否则保存模型时会出现 dtype 错误。

查看saved model输入输出信息

模型保存在文件夹1633915266 (示例)

saved_model_cli show --dir 1633915266 --tag_set serve --signature_def serving_default
saved_model_cli show --dir saved_model --tag_set serve --all

The given SavedModel SignatureDef contains the following input(s):
  inputs['inputs'] tensor_info:
      dtype: DT_STRING
      shape: (-1)
      name: input_example_tensor:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['classes'] tensor_info:
      dtype: DT_STRING
      shape: (-1, 2)
      name: head/Tile:0
  outputs['scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 2)
      name: head/predictions/probabilities:0
Method name is: tensorflow/serving/classify

saved model 转frozen pb

from tensorflow.python.tools import freeze_graph
from tensorflow.python.saved_model import tag_constants

# Source of the graph and weights: the exported SavedModel directory and the
# tag set to load from it.
input_saved_model_dir = "./1633915266/"
saved_model_tags = tag_constants.SERVING

# Nodes to keep as outputs (the output tensor names without the ":0" suffix)
# and the file the frozen GraphDef is written to.
output_node_names = "head/Tile,head/predictions/probabilities"
output_graph_filename = 'frozen_graph.pb'

# The GraphDef / saver / checkpoint / meta-graph inputs are passed as
# None/False because the SavedModel directory above is the input here.
input_graph_filename = None
input_saver_def_path = False
input_binary = False
checkpoint_path = None
input_meta_graph = False
restore_op_name = None
filename_tensor_name = None
clear_devices = False

freeze_graph.freeze_graph(
    input_graph=input_graph_filename,
    input_saver=input_saver_def_path,
    input_binary=input_binary,
    input_checkpoint=checkpoint_path,
    output_node_names=output_node_names,
    restore_op_name=restore_op_name,
    filename_tensor_name=filename_tensor_name,
    output_graph=output_graph_filename,
    clear_devices=clear_devices,
    initializer_nodes="",
    variable_names_whitelist="",
    variable_names_blacklist="",
    input_meta_graph=input_meta_graph,
    input_saved_model_dir=input_saved_model_dir,
    saved_model_tags=saved_model_tags,
)

saved model推理

难点在于构造输入。这里基于 estimator 的 input_fn 取出一个 batch 的数据,给出一个构造序列化 tf.Example 输入的简单示例。

(在task.py修改)

    # Names of the features that are packed into each serialized tf.Example.
    # Only keys of the input batch that appear in this list are used when the
    # request is built.
    parse_example_ids=[
        'doc_event_id',
        'doc_id',
        'doc_event_source_id',
        'event_geo_location',
        'event_country_state',
        'doc_event_publisher_id',
        'event_country',
        'event_hour',
        'event_platform',
        'traffic_source',
        'event_weekend',
        'user_has_already_viewed_doc',
        'doc_event_entity_id',
        'doc_event_topic_id',
        'doc_event_category_id',
        'pop_document_id_conf',
        'pop_publisher_id_conf',
        'pop_source_id_conf',
        'pop_entity_id_conf',
        'pop_topic_id_conf',
        'pop_category_id_conf',
        'pop_document_id_log_01scaled',
        'pop_publisher_id_log_01scaled',
        'pop_source_id_log_01scaled',
        'pop_entity_id_log_01scaled',
        'pop_topic_id_log_01scaled',
        'pop_category_id_log_01scaled',
        'user_views_log_01scaled',
        'doc_views_log_01scaled',
        'doc_event_days_since_published_log_01scaled',
        'doc_event_hour_log_01scaled',
        'ad_id',
        'doc_ad_source_id',
        'ad_advertiser',
        'doc_ad_publisher_id',
        'doc_ad_topic_id',
        'doc_ad_entity_id',
        'doc_ad_category_id',
        'pop_ad_id_conf',
        'user_doc_ad_sim_categories_conf',
        'user_doc_ad_sim_topics_conf',
        'pop_advertiser_id_conf',
        'pop_campain_id_conf_multipl_log_01scaled',
        'pop_ad_id_log_01scaled',
        'pop_advertiser_id_log_01scaled',
        'pop_campain_id_log_01scaled',
        'user_doc_ad_sim_categories_log_01scaled',
        'user_doc_ad_sim_topics_log_01scaled',
        'user_doc_ad_sim_entities_log_01scaled',
        'doc_event_doc_ad_sim_categories_log_01scaled',
        'doc_event_doc_ad_sim_topics_log_01scaled',
        'doc_event_doc_ad_sim_entities_log_01scaled',
        'ad_views_log_01scaled',
        'doc_ad_days_since_published_log_01scaled',
    ]


    def _float_feature(tensor):
        """Build a tf.train.Feature whose float_list holds the values of `tensor`.

        `tensor` is handed to FloatList as its `value`, so it has to be an
        iterable of floats rather than a bare scalar.
        """
        float_list = tf.train.FloatList(value=tensor)
        return tf.train.Feature(float_list=float_list)


    def _int64_feature(tensor):
        """Build a tf.train.Feature whose int64_list holds the values of `tensor`.

        `tensor` is handed to Int64List as its `value`, so it has to be an
        iterable of integers rather than a bare scalar.
        """
        int64_list = tf.train.Int64List(value=tensor)
        return tf.train.Feature(int64_list=int64_list)

    # Benchmark settings: number of untimed warm-up calls, number of timed
    # calls, and how many examples of the evaluation batch go into a request.
    warmup_time = 100
    measure_time = 100
    eval_batch = 4

    # Pull one batch out of the estimator's evaluation input pipeline.
    _dataset = eval_input_fn()
    iterator = _dataset.make_one_shot_iterator()
    one_element = iterator.get_next()

    with tf.Session() as sess:
        data = sess.run(one_element)

    # Load the exported SavedModel as a Python callable. Nothing else is
    # needed for SavedModel inference: the frozen graph is not loaded here
    # and no extra session is opened.
    exported_path = "./1633937564"
    predictor = tf.contrib.predictor.from_saved_model(exported_path)

    # data[0] is indexed by feature name below; presumably it is the features
    # dict of a (features, labels) pair -- confirm against eval_input_fn.
    data0 = data[0]

    # One {feature name: tf.train.Feature} dict per example in the request.
    feature_list = [{} for _ in range(eval_batch)]
    for key in data0:
        if key not in parse_example_ids:
            continue
        tensor = data0[key]  # this is a batch of data
        # The dtype is a property of the whole column, so pick the converter
        # once per feature instead of once per example.
        dtype = str(tensor.dtype)
        if dtype == "float32":
            make_feature = _float_feature
        elif dtype == "int64":
            make_feature = _int64_feature
        else:
            raise ValueError("dtype not supported")
        for i in range(eval_batch):
            feature_list[i][key] = make_feature(tensor[i])

    # Serialize every example; the predictor is fed the serialized strings.
    examples = [
        tf.train.Example(features=tf.train.Features(feature=feature))
        for feature in feature_list
    ]
    example_protos = [example.SerializeToString() for example in examples]

    # "inputs" is the input key of the serving_default signature.
    feed_dict = {"inputs": example_protos}

    # Warm-up calls are excluded from the timing.
    for _ in range(warmup_time):
        output_dict = predictor(feed_dict)

    # NOTE: relies on `time` being imported by the enclosing module.
    time1 = time.time()
    for _ in range(measure_time):
        output_dict = predictor(feed_dict)
    time2 = time.time()
    print("avg time:", (time2 - time1) / measure_time)
    print("output_dict:", output_dict)

parse_example_ids 中的特征名来自上面打印出来的 parse_example_spec。

frozen pb推理

(在task.py修改)

    # Names of the features that are packed into each serialized tf.Example.
    # Only keys of the input batch that appear in this list are used when the
    # request is built.
    parse_example_ids = [
        'doc_event_id',
        'doc_id',
        'doc_event_source_id',
        'event_geo_location',
        'event_country_state',
        'doc_event_publisher_id',
        'event_country',
        'event_hour',
        'event_platform',
        'traffic_source',
        'event_weekend',
        'user_has_already_viewed_doc',
        'doc_event_entity_id',
        'doc_event_topic_id',
        'doc_event_category_id',
        'pop_document_id_conf',
        'pop_publisher_id_conf',
        'pop_source_id_conf',
        'pop_entity_id_conf',
        'pop_topic_id_conf',
        'pop_category_id_conf',
        'pop_document_id_log_01scaled',
        'pop_publisher_id_log_01scaled',
        'pop_source_id_log_01scaled',
        'pop_entity_id_log_01scaled',
        'pop_topic_id_log_01scaled',
        'pop_category_id_log_01scaled',
        'user_views_log_01scaled',
        'doc_views_log_01scaled',
        'doc_event_days_since_published_log_01scaled',
        'doc_event_hour_log_01scaled',
        'ad_id',
        'doc_ad_source_id',
        'ad_advertiser',
        'doc_ad_publisher_id',
        'doc_ad_topic_id',
        'doc_ad_entity_id',
        'doc_ad_category_id',
        'pop_ad_id_conf',
        'user_doc_ad_sim_categories_conf',
        'user_doc_ad_sim_topics_conf',
        'pop_advertiser_id_conf',
        'pop_campain_id_conf_multipl_log_01scaled',
        'pop_ad_id_log_01scaled',
        'pop_advertiser_id_log_01scaled',
        'pop_campain_id_log_01scaled',
        'user_doc_ad_sim_categories_log_01scaled',
        'user_doc_ad_sim_topics_log_01scaled',
        'user_doc_ad_sim_entities_log_01scaled',
        'doc_event_doc_ad_sim_categories_log_01scaled',
        'doc_event_doc_ad_sim_topics_log_01scaled',
        'doc_event_doc_ad_sim_entities_log_01scaled',
        'ad_views_log_01scaled',
        'doc_ad_days_since_published_log_01scaled',
    ]

    def _float_feature(tensor):
        """Build a tf.train.Feature whose float_list holds the values of `tensor`.

        `tensor` is handed to FloatList as its `value`, so it has to be an
        iterable of floats rather than a bare scalar.
        """
        float_list = tf.train.FloatList(value=tensor)
        return tf.train.Feature(float_list=float_list)

    def _int64_feature(tensor):
        """Build a tf.train.Feature whose int64_list holds the values of `tensor`.

        `tensor` is handed to Int64List as its `value`, so it has to be an
        iterable of integers rather than a bare scalar.
        """
        int64_list = tf.train.Int64List(value=tensor)
        return tf.train.Feature(int64_list=int64_list)

    # Pull one batch out of the estimator's evaluation input pipeline.
    _dataset = eval_input_fn()
    iterator = _dataset.make_one_shot_iterator()
    one_element = iterator.get_next()

    with tf.Session() as sess:
        data = sess.run(one_element)

    # Benchmark settings: number of untimed warm-up calls, number of timed
    # calls, and how many examples of the evaluation batch go into a request.
    warmup_time = 100
    measure_time = 100
    eval_batch = 4

    model_path = "./frozen_graph.pb"
    input_names = ["input_example_tensor"]
    output_names = ["head/Tile", "head/predictions/probabilities"]

    # Import the frozen GraphDef into its own graph. name='' keeps the
    # original node names, so the tensors can be looked up as "<node>:0".
    graph = tf.Graph()
    with graph.as_default():
        # tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile.
        with tf.gfile.GFile(model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
    inputs_tf = [graph.get_tensor_by_name(name + ":0") for name in input_names]
    outputs_tf = [graph.get_tensor_by_name(name + ":0") for name in output_names]

    # data[0] is indexed by feature name below; presumably it is the features
    # dict of a (features, labels) pair -- confirm against eval_input_fn.
    data0 = data[0]

    # One {feature name: tf.train.Feature} dict per example in the request.
    feature_list = [{} for _ in range(eval_batch)]
    for key in data0:
        if key not in parse_example_ids:
            continue
        tensor = data0[key]  # this is a batch of data
        # The dtype is a property of the whole column, so pick the converter
        # once per feature instead of once per example.
        dtype = str(tensor.dtype)
        if dtype == "float32":
            make_feature = _float_feature
        elif dtype == "int64":
            make_feature = _int64_feature
        else:
            raise ValueError("dtype not supported")
        for i in range(eval_batch):
            feature_list[i][key] = make_feature(tensor[i])

    # Serialize every example; the graph's input placeholder is fed the
    # serialized strings.
    examples = [
        tf.train.Example(features=tf.train.Features(feature=feature))
        for feature in feature_list
    ]
    example_protos = [example.SerializeToString() for example in examples]

    feed_dict = {inputs_tf[0]: example_protos}

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    # Run the benchmark inside a context manager so the session is closed
    # (and its resources released) when the measurement is done.
    with tf.Session(graph=graph, config=config) as sess:
        # Warm-up runs are excluded from the timing.
        for _ in range(warmup_time):
            output_dict = sess.run(outputs_tf, feed_dict=feed_dict)

        # NOTE: relies on `time` being imported by the enclosing module.
        time1 = time.time()
        for _ in range(measure_time):
            output_dict = sess.run(outputs_tf, feed_dict=feed_dict)
        time2 = time.time()

    print("avg time:", (time2 - time1) / measure_time)
    # sess.run returns the values in the same order as output_names.
    print("output_dict:", output_dict)

frozen pb tf-trt优化

import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

def write_graph_def(graph_def, dump_graph_name):
    """Write `graph_def` to `dump_graph_name`, choosing the format by suffix.

    Names ending in "pb" are written as a serialized binary protobuf; names
    ending in "pbtxt" are written through `tf.io.write_graph`.

    Args:
        graph_def: Graph definition to dump. The binary branch calls its
            `SerializeToString()` method.
        dump_graph_name: Destination path; its suffix selects the format.

    Raises:
        ValueError: If `dump_graph_name` ends in neither "pb" nor "pbtxt".
    """
    # Imported locally: the surrounding script never imports `os`, so the
    # "pbtxt" branch would otherwise fail with a NameError.
    import os

    if dump_graph_name.endswith("pb"):
        with tf.io.gfile.GFile(dump_graph_name, "wb") as f:
            f.write(graph_def.SerializeToString())
    elif dump_graph_name.endswith("pbtxt"):
        # write_graph takes the directory and the file name separately.
        tf.io.write_graph(graph_def, os.path.dirname(dump_graph_name), os.path.basename(dump_graph_name))
    else:
        raise ValueError("dump_graph_name {} is invalid".format(dump_graph_name))

model_path = "./frozen_graph.pb"

# Load the frozen GraphDef from disk. tf.io.gfile.GFile is used here to match
# the file API used by write_graph_def.
with tf.io.gfile.GFile(model_path, "rb") as f:
    frozen_graph = tf.GraphDef()
    frozen_graph.ParseFromString(f.read())

# Output nodes of the model; they are passed as nodes_blacklist below.
output_names = [
    "head/Tile",
    "head/predictions/probabilities",
]

converter = trt.TrtGraphConverter(
    input_graph_def=frozen_graph,
    nodes_blacklist=output_names,
    precision_mode="FP32",  # use "FP16" here for half-precision conversion
    is_dynamic_op=True)

# Run the conversion and dump the resulting GraphDef as a binary protobuf.
trt_graph = converter.convert()
write_graph_def(trt_graph, "trt_graph.pb")

参考

https://tensorflow.google.cn/tutorials/load_data/tfrecord
https://www.tensorflow.org/api_docs/python/tf/io/parse_example
https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/parse-example

tensorflow中tfrecords格式数据读写小记 - 知乎

tensorflow 模型导出总结 - 知乎

一套灵活的Tensorflow Train/Serving方案 - 知乎

https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html

tensorflow estimator使用SavedModel – d0evi1的博客

你可能感兴趣的:(TensorFlow,tensorflow,deep,wide,推理,frozen)