TensorFlow Model Deployment

References:  Tensorflow 模型线上部署
    构建 TensorFlow Serving Java 客户端

  • Docker installation and deployment

    • Installing Docker on Windows

    • tf-serving

        Pull the TensorFlow Serving image and deploy it with Docker. If this step takes up too much space on the C: drive, you can use the Hyper-V tools to move the downloaded image to another drive.

      # Run the following commands in cmd
      docker pull tensorflow/serving   # pull the image
      docker run -itd -p 5000:5000 --name tfserving tensorflow/serving   # start a container with a fixed name
      docker ps  # look up the container ID (dockerID below)
      docker cp ./mnist dockerID:/models  # copy the SavedModel folder into the container; training is covered below
      
      docker exec -it dockerID /bin/bash  # open a shell inside the container
      tensorflow_model_server --port=5000 --model_name=mnist --model_base_path=/models/mnist  # start the serving process inside the container
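
Once tensorflow_model_server is up, it helps to confirm that port 5000 is actually reachable from the host before debugging the clients. A minimal stdlib-only sketch (host and port are just the values used above):

```python
import socket

def port_open(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to (host, port) succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

if __name__ == "__main__":
    # Expect True once the serving container is running and the port is published.
    print(port_open("localhost", 5000))
```

If this prints False, check `docker ps` for the port mapping before blaming the client code.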
      
  • Training the model

      Train with the official MNIST example; only the code paths need changing. Training produces the pb files shown below. Use saved_model_cli show --dir ./mnist/1 --all to inspect the signature and node names (the clients need them later), then copy the model into the container with docker cp ./mnist dockerID:/models. Pay attention to the folder hierarchy here.

    mnist-pb.png

    sigdef.png
    models.png
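
One reason the folder hierarchy matters: TensorFlow Serving expects model_base_path to contain numeric version subdirectories (here mnist/1/) and serves the highest one. A small sketch of that version-selection logic, useful for sanity-checking a layout before the docker cp step:

```python
from pathlib import Path

def latest_version(model_base_path: str) -> Path:
    """Mimic TF Serving's default policy: pick the highest numeric subdirectory."""
    base = Path(model_base_path)
    versions = [p for p in base.iterdir() if p.is_dir() and p.name.isdigit()]
    if not versions:
        raise FileNotFoundError(f"no numeric version dirs under {base}")
    return max(versions, key=lambda p: int(p.name))
```

For the layout above, latest_version("./mnist") returns the path ./mnist/1.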
  • Python client

      Write the prediction code following the official mnist_client.py example

    import grpc
    import tensorflow as tf
    import mnist_input_data  # helper module shipped alongside the official mnist_client.py example
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2_grpc
    
    server = 'localhost:5000'
    
    channel = grpc.insecure_channel(server)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'mnist'
    request.model_spec.signature_name = 'predict_images'
    
    test_data_set = mnist_input_data.read_data_sets('./data').test
    image, label = test_data_set.next_batch(1)
    request.inputs['images'].CopyFrom(tf.make_tensor_proto(image[0], shape=[1, image[0].size]))
    pred = stub.Predict(request, 5.0)  # second argument is the timeout in seconds
    score = pred.outputs['scores'].float_val
    print(score)
    # [1.6178478001727115e-10, 1.6928293322847278e-15, 1.6151154341059737e-05, 0.000658366538118571, 8.010060947860609e-10, 2.2359495588375466e-08, 3.5608297452131843e-13, 0.9993133544921875, 5.620326870570125e-09, 1.1990837265329901e-05]
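
The scores output is the softmax distribution over the ten digits, so the predicted class is just the index of the largest value. Using the score list printed above:

```python
# Softmax scores as printed by the client above; index 7 dominates.
scores = [1.6178478001727115e-10, 1.6928293322847278e-15, 1.6151154341059737e-05,
          0.000658366538118571, 8.010060947860609e-10, 2.2359495588375466e-08,
          3.5608297452131843e-13, 0.9993133544921875, 5.620326870570125e-09,
          1.1990837265329901e-05]

# argmax over the ten class scores
predicted_digit = max(range(len(scores)), key=scores.__getitem__)
print(predicted_digit)  # 7
```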
    
  • Java client

      The Java workflow is much the same; the main extra hassle is compiling the proto files.

    • Installing protoc

        For installing protoc on Windows, see windows之google protobuf安装与使用. Download protoc-3.4.0 and unzip it; make sure the path contains no spaces, otherwise compilation will fail later. Note the path of protoc.exe; mine is D:\protoc-3.4.0-win32\bin

    • Configuring the pom to compile the protos

        This part mainly follows 构建 TensorFlow Serving Java 客户端. The proto file list it gives is excellent (I have not worked out why it is exactly these files; I am not familiar with java-grpc). Following its steps, download the tensorflow and tensorflow-serving projects and copy the corresponding proto files out:

      src/main/proto
      ├── tensorflow
      │   └── core
      │       ├── example
      │       │   ├── example.proto
      │       │   └── feature.proto
      │       ├── framework
      │       │   ├── attr_value.proto
      │       │   ├── function.proto
      │       │   ├── graph.proto
      │       │   ├── node_def.proto
      │       │   ├── op_def.proto
      │       │   ├── resource_handle.proto
      │       │   ├── tensor.proto
      │       │   ├── tensor_shape.proto
      │       │   ├── types.proto
      │       │   └── versions.proto
      │       └── protobuf
      │           ├── meta_graph.proto
      │           └── saver.proto
      └── tensorflow_serving
          └── apis
              ├── classification.proto
              ├── get_model_metadata.proto
              ├── inference.proto
              ├── input.proto
              ├── model.proto
              ├── predict.proto
              ├── prediction_service.proto
              └── regression.proto
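
Copying these files out of the two cloned repositories by hand is error-prone, so the layout above can also be assembled with a short script. This is only a sketch: it assumes src_root is a directory under which the tensorflow/ and tensorflow_serving/ proto trees are laid out as in the listing, and the PROTOS list is abbreviated to a few entries (extend it with the full tree above).

```python
import shutil
from pathlib import Path

# Relative proto paths from the tree above (abbreviated; extend with the full list).
PROTOS = [
    "tensorflow/core/example/example.proto",
    "tensorflow/core/framework/tensor.proto",
    "tensorflow_serving/apis/predict.proto",
    "tensorflow_serving/apis/prediction_service.proto",
]

def copy_protos(src_root: str, dst_root: str = "src/main/proto") -> list:
    """Copy each proto into the Maven proto source root, preserving the
    directory hierarchy that protoc needs to resolve imports."""
    copied = []
    for rel in PROTOS:
        src = Path(src_root) / rel
        dst = Path(dst_root) / rel
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(src, dst)
        copied.append(dst)
    return copied
```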
      

        Create a Maven project, put the proto files above under src/main, and add the following to the pom. Note the explicit protoc executable plus input and output directories; without them compilation fails with protoc did not exit cleanly

      
          
              
      <build>
          <plugins>
              <plugin>
                  <groupId>org.xolstice.maven.plugins</groupId>
                  <artifactId>protobuf-maven-plugin</artifactId>
                  <version>0.5.0</version>
                  <configuration>
                      <protocExecutable>D:\protoc-3.4.0-win32\bin\protoc.exe</protocExecutable>
                      <protoSourceRoot>${project.basedir}/src/main/proto/</protoSourceRoot>
                      <outputDirectory>${project.basedir}/src/main/resources/</outputDirectory>
                  </configuration>
                  <executions>
                      <execution>
                          <goals>
                              <goal>compile</goal>
                              <goal>compile-custom</goal>
                          </goals>
                      </execution>
                  </executions>
              </plugin>
          </plugins>
      </build>

      <dependencies>
          <dependency>
              <groupId>com.google.protobuf</groupId>
              <artifactId>protobuf-java</artifactId>
              <version>3.11.4</version>
          </dependency>
          <dependency>
              <groupId>io.grpc</groupId>
              <artifactId>grpc-protobuf</artifactId>
              <version>1.28.0</version>
          </dependency>
          <dependency>
              <groupId>io.grpc</groupId>
              <artifactId>grpc-stub</artifactId>
              <version>1.28.0</version>
          </dependency>
          <dependency>
              <groupId>io.grpc</groupId>
              <artifactId>grpc-netty-shaded</artifactId>
              <version>1.28.0</version>
          </dependency>
      </dependencies>
          
      
      

        After this configuration, run maven -> protobuf:compile. Two folders, org and tensorflow, are generated under the resources directory; copy both of them into src/main/java

      proto.png

    • Prediction

        While writing the Java prediction program I found that tensorflow/serving/PredictionServiceGrpc.java had not been generated (likely because the compile-custom goal also needs a protoc-gen-grpc-java pluginArtifact configured). After trying many ways to generate it without success, I simply copied someone else's PredictionServiceGrpc. The copied file then flagged the @java.lang.Override lines as errors, so I just commented those overrides out

      Create a package and class under src/main/java and write the prediction code. The full code is as follows; running it gives the prediction result

      package SimpleAdd;
      
      import io.grpc.ManagedChannel;
      import io.grpc.ManagedChannelBuilder;
      import tensorflow.serving.Model;
      import org.tensorflow.framework.DataType;
      import org.tensorflow.framework.TensorProto;
      import org.tensorflow.framework.TensorShapeProto;
      
      import tensorflow.serving.Predict;
      import tensorflow.serving.PredictionServiceGrpc;
      
      
      public class MnistPredict {
          public static void main(String[] args) throws Exception {
              // create a channel for gRPC
              ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 5000).usePlaintext().build();
              PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
      
              // create a modelspec
              Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder();
              modelSpec.setName("mnist");
              modelSpec.setSignatureName("predict_images");
              Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();
              request.setModelSpec(modelSpec);
      
              // data shape & load data
              TensorShapeProto.Builder shape = TensorShapeProto.newBuilder();
              shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
              shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(784));
              TensorProto.Builder tensor = TensorProto.newBuilder();
              tensor.setTensorShape(shape);
              tensor.setDtype(DataType.DT_FLOAT);
              // fill the tensor with an all-zero dummy image (replace with real pixel values)
              for (int i = 0; i < 784; i++) {
                  tensor.addFloatVal(0);
              }
              request.putInputs("images", tensor.build());
      
              // Predict
              Predict.PredictResponse response = stub.predict(request.build());
              TensorProto result = response.getOutputsOrThrow("scores");
              System.out.println("predict: " + result.getFloatValList());
              // predict: [0.032191742, 0.09621494, 0.06525445, 0.039610844, 0.05699038, 0.46822935, 0.040578533, 0.1338098, 0.009549928, 0.057570033]

              channel.shutdown();
          }
      }
      
