本项目代码
import tensorflow as tf
import numpy as np
class MNISTLoader():
    """Load MNIST via ``tf.keras.datasets.mnist.load_data()`` and serve random mini-batches.

    Attributes set by the constructor:
        x_train, x_test: float32 images scaled to [0, 1] with a trailing channel
            axis, i.e. [60000, 28, 28, 1] and [10000, 28, 28, 1].
        y_train, y_test: int32 class labels.
        x_train_count, x_test_count: number of training / test images.
    """

    def __init__(self):
        (train_images, train_labels), (test_images, test_labels) = \
            tf.keras.datasets.mnist.load_data()

        def _to_nhwc(images):
            # scale raw pixels to [0, 1] as float32 and append a single channel axis
            return np.expand_dims(images.astype(np.float32) / 255.0, axis=-1)

        self.x_train = _to_nhwc(train_images)   # [60000, 28, 28, 1]
        self.x_test = _to_nhwc(test_images)     # [10000, 28, 28, 1]
        self.y_train = train_labels.astype(np.int32)
        self.y_test = test_labels.astype(np.int32)
        self.x_train_count = self.x_train.shape[0]
        self.x_test_count = self.x_test.shape[0]

    def get_batch(self, batch_size):
        """Return ``(images, labels)`` for ``batch_size`` training indices drawn uniformly.

        Indices come from ``np.random.randint``, so sampling is with replacement
        and a batch may contain the same example more than once.
        """
        upper = self.x_train.shape[0]
        picks = np.random.randint(0, upper, batch_size)
        return self.x_train[picks, :], self.y_train[picks]
class MLP(tf.keras.Model):
    """Two-layer perceptron for single-channel 28x28 images that returns class probabilities.

    Pipeline: Flatten -> Dense(100, relu) -> Dense(10) -> softmax.
    """

    def __init__(self):
        super().__init__()
        # collapses every axis except the first (batch) axis
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        # one output unit per class
        self.dense2 = tf.keras.layers.Dense(units=10)

    # Fixed input signature (float32, [batch, 28, 28, 1]) so the traced function
    # can be exported by tf.saved_model.save and served.
    @tf.function(input_signature=[tf.TensorSpec([None, 28, 28, 1], tf.float32)])
    def call(self, inputs):
        hidden = self.dense1(self.flatten(inputs))
        logits = self.dense2(hidden)
        return tf.nn.softmax(logits)
# Hyper-parameters for training.
epochs, batch_size, learning_rate = 5, 50, 0.001

data_loader = MNISTLoader()
model = MLP()
# Adam optimizer used to apply the gradients computed during training.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
1、从data_loader中随机取一批数据
2、将这批数据送入模型,计算出模型的预测值
3、预测值与真实值比较,计算损失函数
4、计算损失函数关于模型变量的导数
5、将求出的导数值传入优化器中,使用优化器更新模型参数以最小化损失函数
# Training loop: each step draws a random mini-batch, runs the forward pass under
# a GradientTape, and lets the optimizer update the weights.
# get_batch samples with replacement, so x_train_count // batch_size steps only
# approximate one pass over the training set.
num_batches = int(data_loader.x_train_count // batch_size * epochs)
for batch_index in range(num_batches):
    X, Y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        # y_pred holds softmax probabilities and Y integer labels, so the
        # (non-logits) sparse categorical cross-entropy applies directly.
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=Y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
    print("batch %d: loss %f" % (batch_index, loss.numpy()))
    # Differentiate only w.r.t. trainable weights: non-trainable variables would
    # receive None gradients that apply_gradients cannot use.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))
# Evaluate on the test set batch by batch and report the overall accuracy.
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
for start_index in range(0, data_loader.x_test_count, batch_size):
    # The last slice is simply shorter when x_test_count is not a multiple of
    # batch_size, so no test example is skipped.
    end_index = min(start_index + batch_size, data_loader.x_test_count)
    # Call the model directly: model.predict() adds per-call overhead that is
    # only worthwhile for large inputs, not one small batch at a time.
    y_pred = model(data_loader.x_test[start_index:end_index])
    sparse_categorical_accuracy.update_state(
        y_true=data_loader.y_test[start_index:end_index], y_pred=y_pred)
print("test accuracy: %f" % sparse_categorical_accuracy.result().numpy())
# Export as a SavedModel. TensorFlow Serving loads numbered version
# sub-directories under the model base path (e.g. <base>/1/saved_model.pb),
# so write into version "1" rather than directly into the base directory.
tf.saved_model.save(model, "D:/file/model/1")
assets #模型依赖的外部文件,比如vocab
saved_model.pb #模型的网络结构,可以接受tensor输入,计算完后输出tensor
# saved_model.pb或saved_model.pbtxt是SavedModel协议缓冲区。它将图形定义作为MetaGraphDef协议缓冲区。MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示
variables #模型的参数
# Inspect the exported signatures (input/output tensor names, e.g. args_0 / output_1).
saved_model_cli show --dir model_dir_path --all
# TF Serving expects the mounted base directory to contain numbered version
# sub-directories, e.g. /home/linjie/model/1/saved_model.pb.
# RESTful API (port 8501)
docker run -p 8501:8501 --mount type=bind,source=/home/linjie/model,target=/models/saved_model -e MODEL_NAME=saved_model -t tensorflow/serving &
# gRPC (port 8500)
docker run -p 8500:8500 --mount type=bind,source=/home/linjie/model,target=/models/saved_model -e MODEL_NAME=saved_model -t tensorflow/serving &
1、先配置config文件
# TF Serving multi-model config: one `config` entry per served model.
model_config_list {
  config {
    name: "z_model"                    # any name; clients address the model by it
    base_path: "/models/ble/z_model"   # container path; must live under /models/...
    model_platform: "tensorflow"
  }
  config {
    name: "xy_model"
    base_path: "/models/ble/xy_model"
    model_platform: "tensorflow"
  }
}
2、进行多模型部署
# Serve several models from one container: port 8500 = gRPC, port 8501 = RESTful API.
# --model_config_file is a path inside the container, so it also lives under /models/...
docker run -p 8500:8500 -p 8501:8501 \
    --mount type=bind,source=/home/ble/,target=/models/ble \
    -t tensorflow/serving \
    --model_config_file=/models/ble/model.config
3、多模型部署后,请求地址也有稍微不同
原地址:http://192.168.110.100:8501/v1/models/saved_model:predict
现地址:http://192.168.110.100:8501/v1/models/(config中模型的name):predict
以上均在CentOS7虚拟机上进行,所用详细命令暂不给出
# List running containers
docker ps
# Stop a container; set CONTAINER_ID to an ID printed by `docker ps`
docker stop "$CONTAINER_ID"
# Show the status of the Docker service itself (not of individual containers)
service docker status
import tensorflow as tf
import numpy as np
class MNISTLoader():
    """Load MNIST via ``tf.keras.datasets.mnist.load_data()`` and serve random mini-batches.

    Attributes set by the constructor:
        x_train, x_test: float32 images scaled to [0, 1] with a trailing channel
            axis, i.e. [60000, 28, 28, 1] and [10000, 28, 28, 1].
        y_train, y_test: int32 class labels.
        x_train_count, x_test_count: number of training / test images.
    """

    def __init__(self):
        (train_images, train_labels), (test_images, test_labels) = \
            tf.keras.datasets.mnist.load_data()

        def _to_nhwc(images):
            # scale raw pixels to [0, 1] as float32 and append a single channel axis
            return np.expand_dims(images.astype(np.float32) / 255.0, axis=-1)

        self.x_train = _to_nhwc(train_images)   # [60000, 28, 28, 1]
        self.x_test = _to_nhwc(test_images)     # [10000, 28, 28, 1]
        self.y_train = train_labels.astype(np.int32)
        self.y_test = test_labels.astype(np.int32)
        self.x_train_count = self.x_train.shape[0]
        self.x_test_count = self.x_test.shape[0]

    def get_batch(self, batch_size):
        """Return ``(images, labels)`` for ``batch_size`` training indices drawn uniformly.

        Indices come from ``np.random.randint``, so sampling is with replacement
        and a batch may contain the same example more than once.
        """
        upper = self.x_train.shape[0]
        picks = np.random.randint(0, upper, batch_size)
        return self.x_train[picks, :], self.y_train[picks]
import json
import requests
# Send the first 10 test images to the model served over TF Serving's REST API
# and print the predicted classes next to the true labels.
data_loader = MNISTLoader()
data = json.dumps({"instances": data_loader.x_test[0:10].tolist()})
headers = {"content-type": "application/json"}
json_response = requests.post('http://192.168.110.100:8501/v1/models/saved_model:predict',
                              data=data, headers=headers, timeout=30)
# On failure TF Serving answers with a non-200 status and an {"error": ...} body;
# surface that message instead of failing later with KeyError('predictions').
if json_response.status_code != 200:
    raise RuntimeError("prediction request failed (HTTP %d): %s"
                       % (json_response.status_code, json_response.text))
pre = np.array(json_response.json()['predictions'])  # [10, 10] class probabilities
print(np.argmax(pre, axis=-1))
print(data_loader.y_test[0:10])
# Reference signature of requests.post (module `requests`, not `request`):
#     requests.post(url, data=None, json=None, **kwargs)
# It returns a requests.Response object.
(gRPC调用)若使用Java作为客户端,则需要编译proto文件
参考地址:
1、https://github.com/junwan01/tensorflow-serve-client
2、https://www.cnblogs.com/ustcwx/p/12768463.html
// 需要注意版本问题,由.proto文件编译出来的java class依赖tensorflow的jar包,可能存在不兼容问题
【Linux】
# Linux: fetch TensorFlow Serving and TensorFlow sources at matching 2.1.0 tags.
export SRC=~/Documents/source_code/
mkdir -p "$SRC"
cd "$SRC"
git clone https://github.com/tensorflow/serving.git
cd serving
git checkout tags/2.1.0
# back to $SRC (the original `cd $RSC` referenced an undefined variable)
cd "$SRC"
git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow
git checkout tags/v2.1.0
【Windows】
# Windows (PowerShell): fetch TensorFlow Serving and TensorFlow sources at matching 2.1.0 tags.
# Comments start with '#'; '//' is not a comment in PowerShell (or cmd.exe, which uses REM).
# Create the working folder
mkdir D:/file/source_code
cd D:/file/source_code
# Clone TensorFlow Serving
git clone https://github.com/tensorflow/serving
cd serving
git checkout tags/2.1.0
cd D:/file/source_code
# Clone TensorFlow
git clone https://github.com/tensorflow/tensorflow
cd tensorflow
git checkout tags/v2.1.0
# Copy only the .proto files needed for the Java client into the project's
# proto source tree, keeping the directory layout (empty dirs are pruned).
: "${SRC:?set SRC to the directory holding the serving/ and tensorflow/ checkouts}"
: "${PROJECT_ROOT:?set PROJECT_ROOT to the Java project root}"
mkdir -p "$PROJECT_ROOT/src/main/proto/"
# every .proto under tensorflow_serving/
rsync -arv --prune-empty-dirs --include="*/" --include='*.proto' --exclude='*' \
    "$SRC/serving/tensorflow_serving" "$PROJECT_ROOT/src/main/proto/"
# selected TensorFlow proto directories; stream_executor is included because the
# compile script also runs protoc on tensorflow/stream_executor/dnn.proto
rsync -arv --prune-empty-dirs --include="*/" \
    --include="tensorflow/core/lib/core/*.proto" \
    --include='tensorflow/core/framework/*.proto' \
    --include="tensorflow/core/example/*.proto" \
    --include="tensorflow/core/protobuf/*.proto" \
    --include="tensorflow/stream_executor/*.proto" \
    --exclude='*' \
    "$SRC/tensorflow/tensorflow" "$PROJECT_ROOT/src/main/proto/"
// 因未安装rsync,所以直接拷贝前人准备好的proto文件放置java工程中
参考地址:https://github.com/junwan01/tensorflow-serve-client/tree/master/src/main/proto
<properties>
    <grpc.version>1.20.0</grpc.version>
</properties>
<dependencies>
    <!-- gRPC runtime: protobuf marshalling, generated-stub support, shaded Netty transport -->
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-protobuf</artifactId>
        <version>${grpc.version}</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-stub</artifactId>
        <version>${grpc.version}</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-netty-shaded</artifactId>
        <version>${grpc.version}</version>
    </dependency>
    <dependency>
        <groupId>com.google.api.grpc</groupId>
        <artifactId>proto-google-common-protos</artifactId>
        <version>1.0.0</version>
    </dependency>
</dependencies>
brew install protobuf(用于安装 protoc 编译器;brew 是 macOS 的包管理器,Windows 上没有 brew 命令,实测不可用)
<build>
    <extensions>
        <!-- provides ${os.detected.classifier}, used below to pick native protoc/plugin binaries -->
        <extension>
            <groupId>kr.motd.maven</groupId>
            <artifactId>os-maven-plugin</artifactId>
            <version>1.6.2</version>
        </extension>
    </extensions>
    <plugins>
        <plugin>
            <groupId>org.xolstice.maven.plugins</groupId>
            <artifactId>protobuf-maven-plugin</artifactId>
            <version>0.6.1</version>
            <executions>
                <execution>
                    <goals>
                        <!-- compile: protobuf message classes; compile-custom: gRPC stubs via pluginId -->
                        <goal>compile</goal>
                        <goal>compile-custom</goal>
                    </goals>
                </execution>
            </executions>
            <configuration>
                <checkStaleness>true</checkStaleness>
                <protocArtifact>com.google.protobuf:protoc:3.6.1:exe:${os.detected.classifier}</protocArtifact>
                <pluginId>grpc-java</pluginId>
                <pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
            </configuration>
        </plugin>
    </plugins>
</build>
# Run from the project root.
# protobuf:compile generates the message classes; protobuf:compile-custom runs the
# grpc-java plugin and generates the *Grpc stub classes (e.g. PredictionServiceGrpc).
# NOTE: the author's attempt at this step failed with a version error (unresolved).
mvn protobuf:compile protobuf:compile-custom
编译完成之后,在$PROJECT_ROOT/src/main/resources下会增加一个new_old的文件夹,将里面的./org/tensorflow 和 ./tensorflow/serving 两个文件夹移动至$PROJECT_ROOT/src/main/java下即可
执行失败,所以直接拷贝前人的文件至工程路径。
参考链接:https://github.com/junwan01/tensorflow-serve-client/tree/master/target/generated-sources/protobuf
手动编译相较前者麻烦些,但是可以编译出静态的代码集成至工程中,而不是每次运行都动态生成(未尝试)
// grpc-java repo代码地址:https://github.com/grpc/grpc-java
# Build the grpc-java protoc plugin (protoc-gen-grpc-java); git output is shown as captured.
$ cd $SRC
$ git clone https://github.com/grpc/grpc-java.git
Cloning into 'grpc-java'...
remote: Enumerating objects: 166, done.
remote: Counting objects: 100% (166/166), done.
remote: Compressing objects: 100% (121/121), done.
remote: Total 84096 (delta 66), reused 92 (delta 25), pack-reused 83930
Receiving objects: 100% (84096/84096), 31.18 MiB | 23.14 MiB/s, done.
Resolving deltas: 100% (38843/38843), done.
$ cd grpc-java/compiler/
$ ../gradlew java_pluginExecutable
# The plugin binary should now exist here; the protoc compile script passes this path via --plugin.
$ ls -l build/exe/java_plugin/protoc-gen-grpc-java
// 运行shell脚本,编译protobuf文件
#!/usr/bin/env bash
# Compile the collected .proto files into Java sources: protobuf message classes
# for everything, plus *Grpc.java stubs for the three serving services.
set -e  # stop at the first protoc failure instead of continuing with partial output

export SRC=~/code/TFS_source/
export PROJECT_ROOT=~/java/JavaClient/
# where generated Java sources are written
JAVA_OUT="$PROJECT_ROOT/src/main/java"
# grpc-java protoc plugin built with `../gradlew java_pluginExecutable` in grpc-java/compiler
GRPC_PLUGIN="$SRC/grpc-java/compiler/build/exe/java_plugin/protoc-gen-grpc-java"

cd "$PROJECT_ROOT/src/main/proto/"

# Message classes only.
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow/core/example/*.proto
# append by wangxiao
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow_serving/core/logging.proto
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow/stream_executor/dnn.proto
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow_serving/apis/*.proto
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow_serving/config/*.proto
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow_serving/util/*.proto
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow_serving/sources/storage_path/*.proto
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow/core/protobuf/*.proto
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow/core/lib/core/*.proto
protoc --java_out "$JAVA_OUT" --proto_path ./ ./tensorflow/core/framework/*.proto

# The following 3 cmds generate extra *Grpc.java stub source files in addition to the
# regular protobuf Java source files, written to the same output directory.
# --plugin points protoc at the grpc-java plugin binary ($GRPC_PLUGIN).
protoc --grpc-java_out "$JAVA_OUT" --java_out "$JAVA_OUT" --proto_path ./ tensorflow_serving/apis/prediction_service.proto --plugin=protoc-gen-grpc-java="$GRPC_PLUGIN"
protoc --grpc-java_out "$JAVA_OUT" --java_out "$JAVA_OUT" --proto_path ./ tensorflow_serving/apis/model_service.proto --plugin=protoc-gen-grpc-java="$GRPC_PLUGIN"
protoc --grpc-java_out "$JAVA_OUT" --java_out "$JAVA_OUT" --proto_path ./ tensorflow_serving/apis/session_service.proto --plugin=protoc-gen-grpc-java="$GRPC_PLUGIN"
运行正常的情况下,$PROJECT_ROOT/src/main/java/ 文件夹里应该新增了/org/tensorflow 和 /tensorflow/serving 两个文件夹,至此,编译结束!
1、参考源码
package client;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import tensorflow.serving.Model;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;
import java.util.ArrayList;
import java.util.List;
public class FastTextTFSClient {
    /**
     * Sends one PredictRequest for model "fastText" (signature "fastText_sig_def") to a
     * TensorFlow Serving gRPC endpoint and prints the response.
     *
     * <p>The input "input_x" is a placeholder int32 tensor of shape [1, seqLen] holding
     * 0..seqLen-1; a real client would feed encoded text instead.
     *
     * @param args unused
     * @throws Exception if the RPC fails or waiting for channel shutdown is interrupted
     */
    public static void main(String[] args) throws Exception {
        String host = "127.0.0.1";
        int port = 8500;
        // the model's name.
        String modelName = "fastText";
        int seqLen = 50;
        // placeholder token ids 0..seqLen-1
        List<Integer> intData = new ArrayList<>();
        for (int i = 0; i < seqLen; i++) {
            intData.add(i);
        }
        // plaintext (no TLS) channel to the gRPC port
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        try {
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
            // select model and signature
            Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
            modelSpecBuilder.setName(modelName);
            modelSpecBuilder.setSignatureName("fastText_sig_def");
            Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
            builder.setModelSpec(modelSpecBuilder);
            // input tensor: int32 values, row-major
            TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
            tensorProtoBuilder.setDtype(DataType.DT_INT32);
            for (Integer intDatum : intData) {
                tensorProtoBuilder.addIntVal(intDatum);
            }
            // shape [1, seqLen]
            TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(seqLen));
            tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
            builder.putInputs("input_x", tensorProtoBuilder.build());
            Predict.PredictRequest request = builder.build();
            // blocking call; throws StatusRuntimeException on RPC failure
            Predict.PredictResponse response = stub.predict(request);
            System.out.println(response);
        } finally {
            // release the channel's connections and threads
            channel.shutdown().awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS);
        }
    }
}
2、我的代码
1、以MNIST数据集为例,在java客户端进行调用
2、需要编写load(读取)MNIST数据集的代码
<!-- TensorFlow Java API; TestClient imports org.tensorflow.Tensor / Tensors from it -->
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.7.0</version>
</dependency>
// Mnist.java
package data;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Random;
import java.util.zip.GZIPInputStream;
/**
 * Loader for the MNIST dataset stored as four gzip-compressed IDX files
 * (train/t10k images and labels).
 *
 * @Title:Mnist
 * @Package:data
 * @Description: reads MNIST IDX files into {@link Data} records
 * @author:done
 * @date:2021/8/12 21:36
 */
public class Mnist {
    /** Directory read by {@link #load()}; use {@link #load(String)} for another location. */
    private static final String DEFAULT_DIR = "D:\\dowl\\mnist_dataset\\mnist_dataset";

    /** One MNIST example. */
    public static class Data {
        /** Raw 28x28 pixel bytes in file order (row-major); treat each byte as unsigned. */
        public byte[] data;
        /** Class label, 0-9. */
        public int label;
        /** Pixels scaled to [0, 1]. */
        public float[] input;
        /** One-hot encoding of the label, length 10. */
        public float[] output;
    }

    /** Writes 20 random training images to {@code <i>_<label>.png} in the working directory. */
    public static void main(String[] args) throws Exception {
        Mnist mnist = new Mnist();
        mnist.load();
        System.out.println("Data loaded.");
        Random rand = new Random(System.nanoTime());
        for (int i = 0; i < 20; i++) {
            int idx = rand.nextInt(mnist.trainingSet.length);
            Data d = mnist.getTrainingData(idx);
            BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
            for (int x = 0; x < 28; x++) {
                for (int y = 0; y < 28; y++) {
                    img.setRGB(x, y, toRgb(d.data[y * 28 + x]));
                }
            }
            // ImageIO.write creates the file, or overwrites an existing one.
            ImageIO.write(img, "png", new File(i + "_" + d.label + ".png"));
        }
    }

    /** Grey RGB value for an unsigned pixel byte, inverted: 0 becomes white, 255 black. */
    static int toRgb(byte bb) {
        int b = (255 - (0xff & bb));
        return (b << 16 | b << 8 | b) & 0xffffff;
    }

    Data[] trainingSet;
    Data[] testSet;

    /**
     * Shuffles the training set in place using Fisher-Yates, so every ordering is equally
     * likely and no example is lost or duplicated.
     */
    public void shuffle() {
        Random rand = new Random();
        for (int i = trainingSet.length - 1; i > 0; i--) {
            int j = rand.nextInt(i + 1);
            Data tmp = trainingSet[i];
            trainingSet[i] = trainingSet[j];
            trainingSet[j] = tmp;
        }
    }

    public Data getTrainingData(int idx) {
        return trainingSet[idx];
    }

    /** Copies {@code count} training examples starting at {@code start} into a new array. */
    public Data[] getTrainingSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(trainingSet, start, ret, 0, count);
        return ret;
    }

    public Data getTestData(int idx) {
        return testSet[idx];
    }

    /** Copies {@code count} test examples starting at {@code start} into a new array. */
    public Data[] getTestSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(testSet, start, ret, 0, count);
        return ret;
    }

    /** Loads both splits from {@link #DEFAULT_DIR}. */
    public void load() {
        load(DEFAULT_DIR);
    }

    /**
     * Loads both splits from {@code dir}, which must contain train-images-idx3-ubyte.gz,
     * train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz.
     *
     * @throws RuntimeException if a file cannot be read or is malformed, or the splits are
     *                          not 60000 / 10000 examples
     */
    public void load(String dir) {
        trainingSet = load(new File(dir, "train-images-idx3-ubyte.gz").getPath(),
                new File(dir, "train-labels-idx1-ubyte.gz").getPath());
        testSet = load(new File(dir, "t10k-images-idx3-ubyte.gz").getPath(),
                new File(dir, "t10k-labels-idx1-ubyte.gz").getPath());
        if (trainingSet.length != 60000 || testSet.length != 10000) {
            throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length);
        }
    }

    /** Reads one image file and its label file and pairs them into {@link Data} records. */
    private Data[] load(String imgFile, String labelFile) {
        byte[][] images = loadImages(imgFile);
        byte[] labels = loadLabels(labelFile);
        if (images.length != labels.length) {
            throw new RuntimeException("Images and label doesn't match: " + imgFile + " " + labelFile);
        }
        int len = images.length;
        Data[] data = new Data[len];
        for (int i = 0; i < len; i++) {
            data[i] = new Data();
            data[i].data = images[i];
            data[i].label = 0xff & labels[i];
            data[i].input = dataToInput(images[i]);
            data[i].output = labelToOutput(labels[i]);
        }
        return data;
    }

    /** One-hot vector of length 10 for {@code label}; the byte is read as unsigned. */
    private float[] labelToOutput(byte label) {
        float[] o = new float[10];
        o[0xff & label] = 1;
        return o;
    }

    /** Scales unsigned pixel bytes to floats in [0, 1]. */
    private float[] dataToInput(byte[] b) {
        float[] d = new float[b.length];
        for (int i = 0; i < b.length; i++) {
            d[i] = (b[i] & 0xff) / 255.f;
        }
        return d;
    }

    /** Reads a gzip'd IDX3 image file (magic 0x00000803) of 28x28 images. */
    private byte[][] loadImages(String imgFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000803) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            int rows = in.readInt();
            int cols = in.readInt();
            if (rows != 28 || cols != 28) {
                throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols);
            }
            byte[][] data = new byte[count][rows * cols];
            for (int i = 0; i < count; i++) {
                in.readFully(data[i]);
            }
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + imgFile, ex);
        }
    }

    /** Reads a gzip'd IDX1 label file (magic 0x00000801). */
    private byte[] loadLabels(String labelFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(labelFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000801) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            byte[] data = new byte[count];
            in.readFully(data);
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + labelFile, ex);
        }
    }
}
// TestClient.java
package client;
import data.Mnist;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
/**
 * Classifies MNIST test images with the model "saved_model" hosted by TensorFlow Serving,
 * calling the gRPC Predict API.
 *
 * @Title:TestClient
 * @Package:client
 * @Description: gRPC client for the MNIST MLP exported by the Python training script
 * @author:done
 * @date:2021/8/16 21:49
 */
public class TestClient {
    public static void main(String[] args) throws Exception {
        String host = "192.168.110.100";
        int port = 8500;
        // the model's name.
        // Inspect the signatures with: saved_model_cli show --dir model_dir_path --all
        String modelName = "saved_model";
        // Number of test images (starting at index 0) to classify in one request.
        int numImages = 1;

        Mnist mnist = new Mnist();
        mnist.load();
        Mnist.Data[] data = mnist.getTestSlice(0, numImages);
        for (int i = 0; i < data.length; i++) {
            System.out.println("data[" + i + "]的真实标签为:" + data[i].label);
        }

        // plaintext (no TLS) channel to the gRPC port
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        try {
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
            Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
            modelSpecBuilder.setName(modelName);
            modelSpecBuilder.setSignatureName("serving_default");
            Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
            builder.setModelSpec(modelSpecBuilder);

            // Input tensor: float32 pixels of all images, row-major, shaped
            // [numImages, 28, 28, 1] to match the exported signature
            // tf.TensorSpec([None, 28, 28, 1], tf.float32).
            TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
            tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
            for (Mnist.Data d : data) {
                for (float v : d.input) {
                    tensorProtoBuilder.addFloatVal(v);
                }
            }
            TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(data.length));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(28));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(28));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
            tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
            builder.putInputs("args_0", tensorProtoBuilder.build()); // input signature name args_0

            Predict.PredictRequest request = builder.build();
            System.out.println("******************* 请求响应 *******************");
            Predict.PredictResponse response = stub.predict(request);

            // output_1 holds the class probabilities of every image, flattened row by row.
            List<Float> pro = response.getOutputsMap().get("output_1").getFloatValList();
            int numClasses = pro.size() / data.length;
            for (int i = 0; i < data.length; i++) {
                int preY = argmax(pro, i * numClasses, numClasses);
                System.out.println("data[" + i + "]的分类结果为:" + preY);
            }
        } finally {
            // release the channel's connections and threads
            channel.shutdown().awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS);
        }
    }

    /**
     * Position (relative to {@code offset}) of the first largest value among the
     * {@code count} entries of {@code values} starting at {@code offset}.
     */
    private static int argmax(List<Float> values, int offset, int count) {
        int best = 0;
        for (int k = 1; k < count; k++) {
            if (values.get(offset + k) > values.get(offset + best)) {
                best = k;
            }
        }
        return best;
    }

    static private byte[] loadTensorflowModel(String path){
        try {
            return Files.readAllBytes(Paths.get(path));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    static private Tensor<Double> covertArrayToTensor(double inputs[]){
        return Tensors.create(inputs);
    }
}