Install the dependencies:
sudo apt-get update && sudo apt-get install -y \
build-essential \
curl \
libcurl3-dev \
git \
libfreetype6-dev \
libpng12-dev \
libzmq3-dev \
pkg-config \
python-dev \
python-numpy \
python-pip \
software-properties-common \
swig \
zip \
zlib1g-dev
Install tensorflow-serving-api:
pip install tensorflow-serving-api
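Install the server (this assumes the TensorFlow Serving apt repository has already been added as a package source):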
sudo apt-get update && sudo apt-get install tensorflow-model-server
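Before moving on, it can help to confirm that both installs are visible from Python. The following is a minimal sketch (Python 3), not part of the original setup: `check_serving_install` is a hypothetical helper, and it assumes the pip package exposes a top-level module named `tensorflow_serving` and that the apt package puts `tensorflow_model_server` on the PATH.

```python
import importlib.util
import shutil


def check_serving_install(binary="tensorflow_model_server",
                          package="tensorflow_serving"):
    """Report whether the server binary and the client API package are visible."""
    return {
        # Full path of the executable if it is on PATH, otherwise None.
        "server_binary": shutil.which(binary),
        # True if the package can be located without importing it.
        "api_package": importlib.util.find_spec(package) is not None,
    }


if __name__ == "__main__":
    print(check_serving_install())
```

A `None` for `server_binary` or a `False` for `api_package` suggests the corresponding install step did not complete in the current environment.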
Set the export path and the model version:
import os
import tensorflow as tf
# Export inference model.
output_dir = 'pix2pix_model'
model_version = 1
# The version number becomes a subdirectory of the export directory.
output_path = os.path.join(
    tf.compat.as_bytes(output_dir),
    tf.compat.as_bytes(str(model_version)))
print('Exporting trained model to', output_path)
builder = tf.saved_model.builder.SavedModelBuilder(output_path)
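The version number above is hard-coded. TensorFlow Serving looks for numeric subdirectories under the model base path and by default serves the highest one, and SavedModelBuilder refuses to write into a directory that already exists, so picking the next free number can be convenient. A minimal sketch (Python 3) of that idea follows; `next_model_version` and `export_path_for` are hypothetical helpers, not part of the original code.

```python
import os


def next_model_version(base_dir):
    """Return one more than the highest numeric subdirectory of base_dir."""
    if not os.path.isdir(base_dir):
        return 1
    versions = [
        int(name) for name in os.listdir(base_dir)
        if name.isdecimal() and os.path.isdir(os.path.join(base_dir, name))
    ]
    return max(versions, default=0) + 1


def export_path_for(base_dir, version):
    """Build the versioned export path, e.g. pix2pix_model/1."""
    return os.path.join(base_dir, str(version))
```

For example, `export_path_for('pix2pix_model', next_model_version('pix2pix_model'))` could replace the fixed `output_path` when exporting repeatedly.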
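Use tf.saved_model.utils.build_tensor_info to convert the model's input and output tensors into the TensorInfo form the server expects, then build the signature and save:
image_size = 512
images = tf.placeholder(tf.float32, [None, image_size, image_size, 3])  # model input
model = pix2pix()
# Run inference.
outputs = model.sampler(images)  # model output
sess = tf.Session()  # assumes no session was created earlier in the script
saver = tf.train.Saver()
saver.restore(sess, 'checkpoint-0')  # load the trained model parameters
inputs_tensor_info = tf.saved_model.utils.build_tensor_info(images)
outputs_tensor_info = tf.saved_model.utils.build_tensor_info(outputs)
prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'images': inputs_tensor_info},
        outputs={
            'outputs': outputs_tensor_info,
        },
        method_name=tf.saved_model.signature_constants.REGRESS_METHOD_NAME
    ))
# Register the graph, variables, and signature with the builder before saving.
# The default serving key is the signature_name the client requests.
builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            prediction_signature,
    })
builder.save()  # write the SavedModel to output_path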
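Because I treat this as a regression model, I set method_name=tf.saved_model.signature_constants.REGRESS_METHOD_NAME.
For a classification model queried through the same Predict API, change it to the generic predict method (signature_constants also defines CLASSIFY_METHOD_NAME for the Classify API):
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME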
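The method-name constants are plain strings that name the serving API a signature is intended for. As a quick reference that runs without TensorFlow installed, the sketch below spells out their values; `method_name_for` is a hypothetical lookup helper added only for illustration.

```python
# String values of the constants in tf.saved_model.signature_constants.
CLASSIFY_METHOD_NAME = "tensorflow/serving/classify"
PREDICT_METHOD_NAME = "tensorflow/serving/predict"
REGRESS_METHOD_NAME = "tensorflow/serving/regress"

METHOD_FOR_API = {
    "Classify": CLASSIFY_METHOD_NAME,
    "Predict": PREDICT_METHOD_NAME,
    "Regress": REGRESS_METHOD_NAME,
}


def method_name_for(api):
    """Return the signature method name for a serving API name."""
    try:
        return METHOD_FOR_API[api]
    except KeyError:
        raise ValueError("unknown serving API: {!r}".format(api))
```

PREDICT_METHOD_NAME is the most permissive of the three, since a predict signature may map arbitrary named input tensors to arbitrary named output tensors.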
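After saving, the model files for that version appear under the export path. In this article the export path is pix2pix_model and the version is 1, so the files are written to pix2pix_model/1 (a saved_model.pb file plus a variables directory).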
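A quick way to confirm the export succeeded is to check the version directory for the pieces a SavedModel normally contains. This is a minimal sketch under the assumption that the model was written in binary form (`saved_model.pb`); `is_saved_model_dir` is a hypothetical helper.

```python
import os


def is_saved_model_dir(version_dir):
    """Return True if version_dir holds saved_model.pb and a variables directory."""
    has_graph = os.path.isfile(os.path.join(version_dir, "saved_model.pb"))
    has_variables = os.path.isdir(os.path.join(version_dir, "variables"))
    return has_graph and has_variables


if __name__ == "__main__":
    print(is_saved_model_dir(os.path.join("pix2pix_model", "1")))
```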
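Once the model has been saved as described above, start the server with the following command:
tensorflow_model_server --port=9000 --model_name=pix2pix --model_base_path=/home/detection/tensorflow_serving/example/data/pix2pix_model/
Note that model_base_path must be an absolute path; otherwise the server reports an error.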
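Since a relative model_base_path is an easy mistake to make, one option is to build the command in Python and resolve the path first. The sketch below only assembles the argument list from the three flags shown above; `build_server_command` is a hypothetical helper and the default port is taken from the example command.

```python
import os


def build_server_command(model_name, model_base_path, port=9000):
    """Assemble the tensorflow_model_server command with an absolute base path."""
    return [
        "tensorflow_model_server",
        "--port={}".format(port),
        "--model_name={}".format(model_name),
        # Resolve relative paths against the current working directory.
        "--model_base_path={}".format(os.path.abspath(model_base_path)),
    ]


if __name__ == "__main__":
    print(" ".join(build_server_command("pix2pix", "pix2pix_model")))
```

The returned list can be passed directly to `subprocess.Popen` to launch the server from a script.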
Call the model from the client:
python pix2pix_client.py --num_tests=1000 --server=localhost:9000
pix2pix_client.py is defined as follows:
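import cv2
import numpy as np
import tensorflow as tf
from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
# FLAGS (with the server and image flags) and imread are defined elsewhere in the full script.
def main(_):
    # Open an insecure gRPC channel to the model server.
    host, port = FLAGS.server.split(':')
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    # Read the input image and scale pixel values from [0, 255] to [-1, 1].
    data = imread(FLAGS.image)
    data = data / 127.5 - 1.
    image_size = 512
    sample = []
    sample.append(data)
    sample_image = np.asarray(sample).astype(np.float32)
    # Request the 'pix2pix' model through its default serving signature.
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'pix2pix'
    request.model_spec.signature_name = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    request.inputs['images'].CopyFrom(
        tf.contrib.util.make_tensor_proto(sample_image, shape=[1, image_size, image_size, 3]))
    result_future = stub.Predict.future(request, 5.0)  # 5 seconds
    # float_val is a flat list, so reshape it and map [-1, 1] back to [0, 255].
    response = np.array(
        result_future.result().outputs['outputs'].float_val)
    out = (response.reshape((image_size, image_size, 3)) + 1) * 127.5
    # Swap the red and blue channels before writing with OpenCV.
    out = cv2.cvtColor(out.astype(np.float32), cv2.COLOR_BGR2RGB)
    cv2.imwrite('data/test_result/' + '1.jpg', out)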
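The client relies on two small pieces of arithmetic: scaling pixels to [-1, 1] before the request, and reshaping the flat float_val list and scaling it back afterwards. The sketch below isolates them with plain Python lists so it runs without NumPy; the function names are hypothetical, and the clamping to [0, 255] is an addition that the original client does not perform.

```python
def to_model_range(pixel):
    """Map a pixel value in [0, 255] to [-1, 1], as done before sending."""
    return pixel / 127.5 - 1.0


def to_pixel_range(value):
    """Map a model output in [-1, 1] back to [0, 255], clamping any overshoot."""
    scaled = (value + 1.0) * 127.5
    return min(255.0, max(0.0, scaled))


def unflatten(flat, height, width, channels):
    """Rebuild a height x width x channels nested list from a flat row-major list."""
    if len(flat) != height * width * channels:
        raise ValueError("flat list does not match the requested shape")
    values = iter(flat)
    return [[[next(values) for _ in range(channels)]
             for _ in range(width)]
            for _ in range(height)]
```

In the real client the same steps are applied to whole arrays at once; `np.clip(out, 0, 255)` would be the array equivalent of the clamping step.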
The complete code is available on my GitHub: https://github.com/qinghua2016/pix2pix_server
For calling the server from C++, see:
https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/inception_client.cc
For calling the server from Java, see:
https://github.com/foxgem/how-to/blob/master/tensorflow/clients/src/main/java/foxgem/Launcher.java