(1) Export the model (produce model files in the format the Server expects; the main work is using the builder to attach the signature_def key and method_name)
The bundled MNIST export example lives at serving/tensorflow_serving/example/mnist_saved_model.py (listed below):
# mnist_saved_model.py
from __future__ import print_function

import os
import sys

# This is a placeholder for a Google-internal import.

import tensorflow as tf

import mnist_input_data

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

tf.app.flags.DEFINE_integer('training_iteration', 1000,
                            'number of training iterations.')
tf.app.flags.DEFINE_integer('model_version', 1, 'version number of the model.')
tf.app.flags.DEFINE_string('work_dir', '/tmp', 'Working directory.')
FLAGS = tf.app.flags.FLAGS


def main(_):
  if len(sys.argv) < 2 or sys.argv[-1].startswith('-'):
    print('Usage: mnist_saved_model.py [--training_iteration=x] '
          '[--model_version=y] export_dir')
    sys.exit(-1)
  if FLAGS.training_iteration <= 0:
    print('Please specify a positive value for training iteration.')
    sys.exit(-1)
  if FLAGS.model_version <= 0:
    print('Please specify a positive value for version number.')
    sys.exit(-1)

  # Train model
  print('Training model...')
  mnist = mnist_input_data.read_data_sets(FLAGS.work_dir, one_hot=True)
  sess = tf.InteractiveSession()
  serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
  feature_configs = {'x': tf.FixedLenFeature(shape=[784], dtype=tf.float32)}
  tf_example = tf.parse_example(serialized_tf_example, feature_configs)
  x = tf.identity(tf_example['x'], name='x')  # use tf.identity() to assign name
  y_ = tf.placeholder('float', shape=[None, 10])
  w = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  sess.run(tf.global_variables_initializer())
  y = tf.nn.softmax(tf.matmul(x, w) + b, name='y')
  cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
  train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
  values, indices = tf.nn.top_k(y, 10)
  table = tf.contrib.lookup.index_to_string_table_from_tensor(
      tf.constant([str(i) for i in range(10)]))
  prediction_classes = table.lookup(tf.to_int64(indices))
  for _ in range(FLAGS.training_iteration):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})
  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
  print('training accuracy %g' % sess.run(
      accuracy, feed_dict={x: mnist.test.images,
                           y_: mnist.test.labels}))
  print('Done training!')

  # Export model
  # WARNING(break-tutorial-inline-code): The following code snippet is
  # in-lined in tutorials, please update tutorial documents accordingly
  # whenever code changes.
  export_path_base = sys.argv[-1]
  export_path = os.path.join(
      tf.compat.as_bytes(export_path_base),
      tf.compat.as_bytes(str(FLAGS.model_version)))
  print('Exporting trained model to', export_path)
  builder = tf.saved_model.builder.SavedModelBuilder(export_path)

  # Build the signature_def_map.
  classification_inputs = tf.saved_model.utils.build_tensor_info(
      serialized_tf_example)
  classification_outputs_classes = tf.saved_model.utils.build_tensor_info(
      prediction_classes)
  classification_outputs_scores = tf.saved_model.utils.build_tensor_info(values)

  classification_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={
              tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                  classification_inputs
          },
          outputs={
              tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                  classification_outputs_classes,
              tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
                  classification_outputs_scores
          },
          method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME))

  tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
  tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

  prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'images': tensor_info_x},
          outputs={'scores': tensor_info_y},
          # method_name (required)
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

  legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
  builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING],
      # signature_def_map (required)
      signature_def_map={
          'predict_images':
              prediction_signature,
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              classification_signature,
      },
      legacy_init_op=legacy_init_op)

  builder.save()
  print('Done exporting!')


if __name__ == '__main__':
  tf.app.run()
Then run the following commands:
# Make sure this directory does not already exist (you can also pick a custom directory)
rm -rf /tmp/mnist_model
# Build mnist_saved_model.py; the target name mnist_saved_model is configured in the BUILD file in the same directory (open BUILD and it is obvious)
bazel build //tensorflow_serving/example:mnist_saved_model
# Run mnist_saved_model.py, passing the export path on the command line (or hard-code it in the script)
bazel-bin/tensorflow_serving/example/mnist_saved_model /tmp/mnist_model
# Inspect the exported tf-serving model: it sits in the '1' subdirectory, where '1' is the model version; if you add a '2' directory and start the model server, tf-serving automatically picks the highest version to load
ls /tmp/mnist_model
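A successful export produces one numbered version directory holding the serialized graph plus a variables folder; the listing should look roughly like this (shard file names can vary slightly across TF versions):
/tmp/mnist_model
└── 1
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00001
        └── variables.index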
If you already have a trained model of your own, just place its files under the target directory (e.g. /tmp/mnist_model/1/) and make sure all of the signature_def attributes are present.
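To confirm those attributes actually made it into the export, you can load the SavedModel back and list its signatures. The snippet below is a minimal sketch (the check_signatures.py name is just for illustration; it assumes TensorFlow 1.x):
# check_signatures.py -- minimal sketch: load the exported SavedModel and list
# its signature_def keys, so you can verify 'predict_images' and the default
# classification signature are present before starting the server.
from __future__ import print_function
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
  # loader.load returns the MetaGraphDef for the requested tags.
  meta_graph = tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING], '/tmp/mnist_model/1')
  for key, sig in meta_graph.signature_def.items():
    print(key, '->', sig.method_name)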
(2) Start the server and load the model
# First build the tf-serving model server
bazel build //tensorflow_serving/model_servers:tensorflow_model_server
# Start tf-serving, passing the port, a model name (any name you like), and the path of the model to load
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=mnist --model_base_path=/tmp/mnist_model/
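Before writing the full client, you can sanity-check that the model actually loaded. The sketch below asks the server for the status of the 'mnist' model over gRPC; it assumes grpcio plus a tensorflow-serving-api version new enough to ship the ModelService stubs (the beta-API era packages may not have them):
# model_status.py -- minimal sketch: query the running server for the state of
# the 'mnist' model; a loaded model should report state AVAILABLE.
import grpc
from tensorflow_serving.apis import get_model_status_pb2
from tensorflow_serving.apis import model_service_pb2_grpc

channel = grpc.insecure_channel('localhost:9000')
stub = model_service_pb2_grpc.ModelServiceStub(channel)
request = get_model_status_pb2.GetModelStatusRequest()
request.model_spec.name = 'mnist'  # must match --model_name above
for version in stub.GetModelStatus(request, 5.0).model_version_status:
  print(version.version, version.state)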
(3) Once the server has started successfully, write the client
The bundled MNIST client lives at serving/tensorflow_serving/example/mnist_client.py (listed below). The main things to get right are the server's IP and port, plus:
request.model_spec.name = 'mnist' (must match the server's --model_name)
request.model_spec.signature_name = 'predict_images' (must match a key in the exported signature_def_map)
# mnist_client.py
from __future__ import print_function

import sys
import threading

# This is a placeholder for a Google-internal import.

from grpc.beta import implementations
import numpy
import tensorflow as tf

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
import mnist_input_data

tf.app.flags.DEFINE_integer('concurrency', 1,
                            'maximum number of concurrent inference requests')
tf.app.flags.DEFINE_integer('num_tests', 100, 'Number of test images')
tf.app.flags.DEFINE_string('server', '', 'PredictionService host:port')
tf.app.flags.DEFINE_string('work_dir', '/tmp', 'Working directory.')
FLAGS = tf.app.flags.FLAGS


class _ResultCounter(object):
  """Counter for the prediction results."""

  def __init__(self, num_tests, concurrency):
    self._num_tests = num_tests
    self._concurrency = concurrency
    self._error = 0
    self._done = 0
    self._active = 0
    self._condition = threading.Condition()

  def inc_error(self):
    with self._condition:
      self._error += 1

  def inc_done(self):
    with self._condition:
      self._done += 1
      self._condition.notify()

  def dec_active(self):
    with self._condition:
      self._active -= 1
      self._condition.notify()

  def get_error_rate(self):
    with self._condition:
      while self._done != self._num_tests:
        self._condition.wait()
      return self._error / float(self._num_tests)

  def throttle(self):
    with self._condition:
      while self._active == self._concurrency:
        self._condition.wait()
      self._active += 1


def _create_rpc_callback(label, result_counter):
  """Creates RPC callback function.

  Args:
    label: The correct label for the predicted example.
    result_counter: Counter for the prediction result.
  Returns:
    The callback function.
  """
  def _callback(result_future):
    """Callback function.

    Calculates the statistics for the prediction result.

    Args:
      result_future: Result future of the RPC.
    """
    exception = result_future.exception()
    if exception:
      result_counter.inc_error()
      print(exception)
    else:
      sys.stdout.write('.')
      sys.stdout.flush()
      response = numpy.array(
          result_future.result().outputs['scores'].float_val)
      prediction = numpy.argmax(response)
      if label != prediction:
        result_counter.inc_error()
    result_counter.inc_done()
    result_counter.dec_active()
  return _callback


def do_inference(hostport, work_dir, concurrency, num_tests):
  """Tests PredictionService with concurrent requests.

  Args:
    hostport: Host:port address of the PredictionService.
    work_dir: The full path of working directory for test data set.
    concurrency: Maximum number of concurrent requests.
    num_tests: Number of test images to use.

  Returns:
    The classification error rate.

  Raises:
    IOError: An error occurred processing test data set.
  """
  test_data_set = mnist_input_data.read_data_sets(work_dir).test
  host, port = hostport.split(':')
  channel = implementations.insecure_channel(host, int(port))
  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
  result_counter = _ResultCounter(num_tests, concurrency)
  for _ in range(num_tests):
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'mnist'
    request.model_spec.signature_name = 'predict_images'
    image, label = test_data_set.next_batch(1)
    request.inputs['images'].CopyFrom(
        tf.contrib.util.make_tensor_proto(image[0], shape=[1, image[0].size]))
    result_counter.throttle()
    result_future = stub.Predict.future(request, 5.0)  # 5 seconds
    result_future.add_done_callback(
        _create_rpc_callback(label[0], result_counter))
  return result_counter.get_error_rate()


def main(_):
  if FLAGS.num_tests > 10000:
    print('num_tests should not be greater than 10k')
    return
  if not FLAGS.server:
    print('please specify server host:port')
    return
  error_rate = do_inference(FLAGS.server, FLAGS.work_dir,
                            FLAGS.concurrency, FLAGS.num_tests)
  print('\nInference error rate: %s%%' % (error_rate * 100))


if __name__ == '__main__':
  tf.app.run()
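The grpc.beta API used above was deprecated in later grpcio releases. On a newer grpcio / tensorflow-serving-api stack, the same single Predict call looks roughly like the sketch below (a random array stands in for a real MNIST image here):
# minimal_predict.py -- minimal sketch of one Predict call via the non-beta
# grpc API; assumes grpcio and a tensorflow-serving-api that provides
# prediction_service_pb2_grpc.
import grpc
import numpy
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

channel = grpc.insecure_channel('localhost:9000')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'mnist'                     # matches --model_name
request.model_spec.signature_name = 'predict_images'  # matches the exported key
image = numpy.random.rand(1, 784).astype(numpy.float32)  # stand-in input
request.inputs['images'].CopyFrom(
    tf.contrib.util.make_tensor_proto(image, shape=[1, 784]))

response = stub.Predict(request, 5.0)  # 5 second timeout
scores = numpy.array(response.outputs['scores'].float_val)
print('predicted digit:', numpy.argmax(scores))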
All that remains is to build the client code, as follows:
# Build the client Python code (mnist_client refers to mnist_client.py; like before, the target is configured in the BUILD file, which plays a role similar to a makefile)
bazel build //tensorflow_serving/example:mnist_client
# Run the client and make predictions
bazel-bin/tensorflow_serving/example/mnist_client --num_tests=1000 --server=localhost:9000
A few topics remain, namely how the BUILD file itself is constructed and the C++ and Python APIs for feeding input data in and pulling prediction results out; those will be covered later when time permits.
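In the meantime, for orientation, the py_binary targets used above look roughly like this in tensorflow_serving/example/BUILD (a sketch, not the verbatim file; the real targets carry additional TensorFlow and serving-API deps that are elided here):
# tensorflow_serving/example/BUILD (sketch)
py_binary(
    name = "mnist_saved_model",   # built via bazel build //tensorflow_serving/example:mnist_saved_model
    srcs = ["mnist_saved_model.py"],
    deps = [
        ":mnist_input_data",      # local helper for reading the MNIST data set
        # ...TensorFlow / serving deps elided...
    ],
)

py_binary(
    name = "mnist_client",
    srcs = ["mnist_client.py"],
    deps = [
        ":mnist_input_data",
        # ...serving API deps elided...
    ],
)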