TensorFlow-Serving入门

本文中我们在Mac机器上使用Docker配置TensorFlow-Serving环境,并提供Http预测接口。

安装Docker
brew cask install docker
下载TensorFlow-Serving镜像
docker pull tensorflow/serving
生成SavedModel模型

TensorFlow主要有三种模型格式:CheckPoint(.ckpt),SavedModel,GraphDef(*.pb)。这三种格式之间可以互相转换,CheckPoint格式在训练模型时候每隔几轮保存一次,以方便增量训练。GraphDef格式适用于python、java的tensorflow库进行加载,SavedModel是TensorFlow-Serving要求的格式。我们最开始的是一个xception.pb模型(GraphDef格式),这里需要将其转换为SavedModel格式,代码如下:

import tensorflow.compat.v1 as tf
import time
tf.disable_v2_behavior()
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

export_dir = '/Work/infra/tensorflow/saved_model'
graph_pb = '/Work/infra/tensorflow/xception.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
    # name="" is important to ensure we don't get spurious prefixing
    tf.import_graph_def(graph_def, name="")
    g = tf.get_default_graph()
    inp = g.get_tensor_by_name("input_1:0") //输入节点名字
    out = g.get_tensor_by_name("output:0") // 输出节点名字

    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \ 
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in": inp}, {"out": out})

    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sigs)

builder.save()

转换完成后,目录下除了saved_model.pb文件,还多了一个variables文件夹,该文件夹为空,因为从pb文件转换过来的时候全是常量没有变量。

启动TensorFlow-Serving
docker run -p 8501:8501 --name tfserving_testnet  --mount type=bind,source=/Wor
k/infra/tensorflow/xception,target=/models/xception  -e MODEL_NAME=xception -t tensorflow/serving
  • 8051:http端口,前面的为本机的端口,后面的为docker中的端口。
  • name:名字随便起,为了识别docker的container
  • source:模型在本机上的位置目录
  • target:/models/固定,后面的名字随便起,最好和模型名字一致
  • MODEL_NAME:设置Docker中的环境变量,和上面的target名字一致
Http接口
  • 查看TensorFlow-Serving状态:curl http://localhost:8501/v1/models/xception
  • 查看TensorFlow-Serviing模型:curl http://localhost:8501/v1/models/xception/metadata
  • 使用Http请求进行模型预测:curl -d '{"instances": [1,2,3,4,5]}' -X POST http://localhost:8501/v1/models/xception:predict,其中instances的value为模型输入Tensor的字符串形式,矩阵维度需要和Tensor对应。
Python客户端

这里我们使用Python加载图片数据,并发向TensorFlow Http接口进行预测图片质量。Http接口的数据需要是Json格式,代码如下:

SERVER_URL = 'http://localhost:8501/v1/models/xception:predict'
def prediction():
    images = image.load_img("test.jpg", target_size=(480, 480))
    x = image.img_to_array(images)
    x = np.expand_dims(x, axis=0)
    image_np = xception.preprocess_input(x)
    #print str(image_np.tolist())
    predict_request='{"instances":%s}' % str(image_np.tolist())
    #predict_request='{"instances":%s}' % str([[[[1]*3]*480]*480])
    response = requests.post(SERVER_URL, data=predict_request)
    prediction = response.json()
    print(prediction)
        
if __name__ == "__main__":
    prediction()

你可能感兴趣的:(TensorFlow-Serving入门)