tensorflow serving (二):发布自己的服务

https://www.jianshu.com/p/d673c9507988
通过简单运行了官网例子,对tensorflow serving有了大致的了解,但是怎么把自己的模型发布成服务呢?现在通过一个小例子来学习下。

0. 介绍

这里介绍两种保存模型的方式,发布服务需要的不再是之前保存的ckpt格式数据,而是export出来的模型或者pb模型。通过这两种方式把模型准备好,之后只需要挂在到指定路径下,就可以起服务了。

1. 1 exporter 模型

把官方的half_plus_two简单修改成了half_plus_ten。
与我们保存ckpt不同,需要调用的接口是:

from tensorflow.contrib.session_bundle import exporter

需要把输入输出给重新定义下,然后再用接口导出。

import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter


def Export():
  export_path = "model/half_plus_ten"
  with tf.Session() as sess:
    # Make model parameters a&b variables instead of constants to
    # exercise the variable reloading mechanisms.
    a = tf.Variable(0.5)
    b = tf.Variable(10.0)

    # Calculate, y = a*x + b
    # here we use a placeholder 'x' which is fed at inference time.
    x = tf.placeholder(tf.float32)
    y = tf.add(tf.multiply(a, x), b)

    # Run an export.
    tf.global_variables_initializer().run()
    export = exporter.Exporter(tf.train.Saver())
    export.init(named_graph_signatures={
        "inputs": exporter.generic_signature({"x": x}),
        "outputs": exporter.generic_signature({"y": y}),
        "regress": exporter.regression_signature(x, y)
    })
    export.export(export_path, tf.constant(123), sess)


def main(_):
  Export()

if __name__ == "__main__":
    tf.app.run()

保存好的模型看起来很像ckpt,但是再checkpoint里面可以看到,是“export”。 “00000123”这个文件名是自动生成的,我也不知道为什么会刚好是这个数字。


tensorflow serving (二):发布自己的服务_第1张图片
保存好的模型

1.2 保存pb模型

https://www.jianshu.com/p/9221fbf52c55 通过这个教程,我们把模型保存为pb格式。同样把这个模型文件夹挂在到docker相应的目录下。

tensorflow serving (二):发布自己的服务_第2张图片
保存为pb模型

2. 通过docker起服务

要指定端口,挂载目录,docker才能访问这个模型,挂在的目录得是绝对路径。

  1. export之后的模型挂载。
docker run -t --rm -p 8501:8501 \
   -v "$(pwd)/model/half_plus_ten:/models/half_plus_ten" \
   -e MODEL_NAME=half_plus_ten \
   tensorflow/serving
  1. pb模型需要修改挂载路径,可以重新给模型起名字,这里还是用上面的名字“half_plus_ten"。
docker run -t --rm -p 8501:8501 \
   -v "$(pwd)/pb_model:/models/half_plus_ten" \      
   -e MODEL_NAME=half_plus_ten \
   tensorflow/serving

3. 测试服务

给它几个值来测试下这个服务。

curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_ten:predict

能得到half plus ten这个结果!


输出正确

用python代码访问服务

import os

import requests
from time import time

import numpy as np

url = 'http://localhost:8501/v1/models/half_plus_ten:predict'

a = np.array([1,2 ,3,4])

predict_request = '{"instances" : [{"input": %s}]}' % list(a)  # 一定要list才能传输,不然json错误


print("start")
start_time = time()
r = requests.post(url,data=predict_request)
print(r.content)
end_time = time()

Tips:

代码改写自官方例子:https://github.com/tensorflow/serving/blob/master/tensorflow_serving/servables/tensorflow/testdata/export_half_plus_two.py
代码和模型都放在:
https://github.com/xxlxx1/learing_tf_serving

你可能感兴趣的:(tensorflow serving (二):发布自己的服务)