Using SavedModel-Format Models in TensorFlow

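SavedModel is one of the formats TensorFlow supports for saving models, and it is the format TF-Serving requires for its model files.
Using MNIST handwritten-digit recognition as an example, this post shows how to save and restore a model in this format and how to use it for inference.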

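  • Saving a model in SavedModel format
    For this step, refer to the MNIST training example that ships with TensorFlow Serving (http://github.com/tensorflow/serving.git), located at ./tensorflow_serving/example/mnist_saved_model.py. Running that script trains an MNIST model and writes it out in SavedModel format.
    The code at the end of the script is what saves the model as a SavedModel: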
# Export model
  # WARNING(break-tutorial-inline-code): The following code snippet is
  # in-lined in tutorials, please update tutorial documents accordingly
  # whenever code changes.
  export_path_base = sys.argv[-1]
  export_path = os.path.join(
      tf.compat.as_bytes(export_path_base),
      tf.compat.as_bytes(str(FLAGS.model_version)))
  print('Exporting trained model to', export_path)
  builder = tf.saved_model.builder.SavedModelBuilder(export_path)

  # Build the signature_def_map.
  classification_inputs = tf.saved_model.utils.build_tensor_info(
      serialized_tf_example)
  classification_outputs_classes = tf.saved_model.utils.build_tensor_info(
      prediction_classes)
  classification_outputs_scores = tf.saved_model.utils.build_tensor_info(values)

  classification_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={
              tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                  classification_inputs
          },
          outputs={
              tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                  classification_outputs_classes,
              tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
                  classification_outputs_scores
          },
          method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME))

  tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
  tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

  prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'images': tensor_info_x},
          outputs={'scores': tensor_info_y},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

  legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
  builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING],
      signature_def_map={
          'predict_images':
              prediction_signature,
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              classification_signature,
      },
      legacy_init_op=legacy_init_op)

  builder.save()
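The export code above writes the model to export_path_base/model_version, and TF-Serving by default serves the highest-numbered version directory it finds under the base path. The following is a minimal sketch, not part of the original script, for checking an export on disk with plain Python. The helper names latest_version_dir and looks_like_saved_model are hypothetical, and the check assumes the default binary export (saved_model.pb plus a variables/ directory).

```python
import os


def latest_version_dir(export_path_base):
    """Return the highest-numbered version directory under the base path, or None."""
    versions = [
        name for name in os.listdir(export_path_base)
        if name.isdigit() and os.path.isdir(os.path.join(export_path_base, name))
    ]
    if not versions:
        return None
    # Compare numerically so that version "10" ranks above version "2".
    return os.path.join(export_path_base, max(versions, key=int))


def looks_like_saved_model(version_dir):
    """Check for the files a default (binary) SavedModel export contains."""
    has_graph = os.path.isfile(os.path.join(version_dir, "saved_model.pb"))
    has_variables = os.path.isdir(os.path.join(version_dir, "variables"))
    return has_graph and has_variables
```

For example, latest_version_dir("./models/mnist") would return "./models/mnist/1" when 1 is the only version exported so far.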
  • Restoring a SavedModel and running inference
    Here is a sample script:
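# Written for the TensorFlow 1.x API (tf.Session, tf.saved_model.loader).
import sys

import tensorflow as tf
from PIL import Image
from tensorflow.python.saved_model import tag_constants

# 'images' and 'scores' are the input/output keys of the 'predict_images'
# signature defined in the export code. The default serving signature is the
# classification one, which has no 'images' input, so it cannot be used here.
signature_key = 'predict_images'
input_key = 'images'
output_key = 'scores'

export_dir = "./models/mnist/1/"

image_content = []

# To try the model without an image file, fill image_content with 784
# random values instead (requires `import random`):
# for i in range(784):
#     image_content.append(random.random())

# Read the image given on the command line.
img = Image.open(sys.argv[1])
pix = img.load()
width = img.width
height = img.height

print("Image: width = %d, height = %d, mode = %s" % (width, height, img.mode))

# Flatten the image row by row into grayscale values scaled to [0, 1],
# printing the raw values as a grid along the way.
for y in range(height):
    for x in range(width):
        if img.mode == 'P' or img.mode == 'L':
            gray = pix[x, y]
        elif img.mode == 'RGB':
            r, g, b = pix[x, y]
            gray = (0.3 * r) + (0.59 * g) + (0.11 * b)
        else:
            # Stop instead of feeding the model an incomplete input.
            sys.exit("unsupported mode: %s" % img.mode)
        print("%3d" % int(gray), end=" ")
        image_content.append(gray / 255.0)
    print("")

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    # Load the graph and variables tagged for serving into this session.
    meta_graph_def = tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_dir)

    # Look up the real tensor names behind the signature's input and output keys.
    signature = meta_graph_def.signature_def
    x_tensor_name = signature[signature_key].inputs[input_key].name
    y_tensor_name = signature[signature_key].outputs[output_key].name

    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)

    # The model takes a batch, so wrap the single image in a list.
    y_out = sess.run(y, feed_dict={x: [image_content]})

    print('---------- inference results ----------------')
    print(y_out)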
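The preprocessing loop in the script above can also be factored into a small function that is easy to test without TensorFlow or PIL. The following is a sketch under assumptions: the function name to_model_input is hypothetical, pixels are passed in as plain nested lists rather than a PIL pixel-access object, and palette ('P') images are left out. It uses the same 0.3/0.59/0.11 grayscale weights and the same division by 255 as the script.

```python
def to_model_input(pixels, mode):
    """Flatten rows of pixels into a list of floats in [0, 1].

    pixels: list of rows; each pixel is an int for mode 'L'
            or an (r, g, b) tuple for mode 'RGB'.
    """
    values = []
    for row in pixels:
        for p in row:
            if mode == "L":
                gray = p
            elif mode == "RGB":
                r, g, b = p
                # Same weighted grayscale conversion as in the script.
                gray = 0.3 * r + 0.59 * g + 0.11 * b
            else:
                raise ValueError("unsupported mode: %s" % mode)
            values.append(gray / 255.0)
    return values
```

For a 28x28 MNIST image this yields the 784 values that the model input expects.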

Running the script on an MNIST handwritten-digit image gives the following result:


[Figure 1: inference output for the test image (image.png)]
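To turn the printed scores into a predicted digit, take the index of the largest score. The sketch below assumes that the 'scores' output holds one score per digit class 0 through 9 for each input image, which is how the MNIST example is understood here. The function name predicted_digit is hypothetical, and any numbers passed to it in examples are illustrative, not real model output.

```python
def predicted_digit(scores):
    """Return (digit, score) for the highest-scoring class in one row of scores."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return best, scores[best]
```

With the inference script, this would be called on the first (and only) row of the batch, as predicted_digit(y_out[0]).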
