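In earlier posts I described saving a model in TensorFlow, serving it with TensorFlow Serving, and then calling that service from Java. That is quite a detour. Today I found that Java can load the same SavedModel format that TensorFlow Serving uses, directly. As mentioned before, Java can also load a model from a .pb file directly, so this is simply another way to do it.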
Straight to the code. This is the Python script that trains and exports the model. The export can also be served by TensorFlow Serving:
```python
import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

tf.app.flags.DEFINE_integer('training_iteration', 302,
                            'number of training iterations.')
tf.app.flags.DEFINE_integer('model_version', 1, 'version number of the model.')
tf.app.flags.DEFINE_string('work_dir', 'model/', 'Working directory.')
FLAGS = tf.app.flags.FLAGS

sess = tf.InteractiveSession()

# name="x" and name="y" give these ops fixed names; the Java code feeds and fetches by them
x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
y_ = tf.placeholder(tf.float32, shape=[None, 1])
w = tf.get_variable('w', shape=[3, 1], initializer=tf.truncated_normal_initializer())
b = tf.get_variable('b', shape=[1], initializer=tf.zeros_initializer())
sess.run(tf.global_variables_initializer())

y = tf.add(tf.matmul(x, w), b, name="y")
ms_loss = tf.reduce_mean((y - y_) ** 2)
train_step = tf.train.GradientDescentOptimizer(0.005).minimize(ms_loss)

train_x = np.random.randn(1000, 3)
# let the model learn the equation of y = x1 * 1 + x2 * 2 + x3 * 3
train_y = np.sum(train_x * np.array([1, 2, 3]) + np.random.randn(1000, 3) / 100, axis=1).reshape(-1, 1)

train_loss = []
for i in range(FLAGS.training_iteration):
    loss, _ = sess.run([ms_loss, train_step], feed_dict={x: train_x, y_: train_y})
    train_loss.append(loss)

# Export to <work_dir>/<model_version>, i.e. model/1; SavedModelBuilder refuses an existing directory
export_path_base = FLAGS.work_dir
export_path = os.path.join(
    tf.compat.as_bytes(export_path_base),
    tf.compat.as_bytes(str(FLAGS.model_version)))
print('Exporting trained model to', export_path)
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input': tensor_info_x},
        outputs={'output': tensor_info_y},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
# The SERVING tag is the string "serve", which the Java side passes to SavedModelBundle.load
builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        'prediction': prediction_signature,
    },
    legacy_init_op=legacy_init_op)
builder.save()

print('Training error %g' % loss)
print('Done exporting!')
print('Done training!')
```
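To see what the graph above actually optimizes without any TensorFlow dependency, here is a plain-Java sketch of the same setup: a linear model y = x·w + b, mean squared error, and full-batch gradient descent with learning rate 0.005 on 1000 samples drawn from y = x1·1 + x2·2 + x3·3 plus small noise. It is an illustration, not a port of the script. It uses `java.util.Random` instead of NumPy, starts the weights at zero instead of a truncated normal, and the class and method names are hypothetical.

```java
import java.util.Random;

public class LinearRegressionGd {
    // Full-batch gradient descent on mean((x·w + b - y)^2), mirroring the Python graph.
    // Returns {w1, w2, w3, b}.
    static double[] train(int samples, int iterations, double learningRate, long seed) {
        Random rng = new Random(seed);
        double[] trueW = {1, 2, 3};
        double[][] x = new double[samples][3];
        double[] y = new double[samples];
        for (int n = 0; n < samples; n++) {
            for (int j = 0; j < 3; j++) {
                x[n][j] = rng.nextGaussian();
                // Same noise scheme as the script: randn / 100 added to each feature term
                y[n] += x[n][j] * trueW[j] + rng.nextGaussian() / 100;
            }
        }
        double[] w = new double[3]; // zero init here; the script uses a truncated-normal init
        double b = 0;
        for (int it = 0; it < iterations; it++) {
            double[] gradW = new double[3];
            double gradB = 0;
            for (int n = 0; n < samples; n++) {
                double err = x[n][0] * w[0] + x[n][1] * w[1] + x[n][2] * w[2] + b - y[n];
                // Gradient of the mean squared error: 2/N * sum(err * x) and 2/N * sum(err)
                for (int j = 0; j < 3; j++) {
                    gradW[j] += 2 * err * x[n][j] / samples;
                }
                gradB += 2 * err / samples;
            }
            for (int j = 0; j < 3; j++) {
                w[j] -= learningRate * gradW[j];
            }
            b -= learningRate * gradB;
        }
        return new double[] {w[0], w[1], w[2], b};
    }

    public static void main(String[] args) {
        double[] s = train(1000, 302, 0.005, 42);
        double[] l = train(1000, 3000, 0.005, 42);
        System.out.printf("302 steps:  w=[%.3f, %.3f, %.3f] b=%.3f%n", s[0], s[1], s[2], s[3]);
        System.out.printf("3000 steps: w=[%.3f, %.3f, %.3f] b=%.3f%n", l[0], l[1], l[2], l[3]);
    }
}
```

Comparing the two runs shows that 302 steps at this learning rate leave the weights noticeably short of [1, 2, 3], while 3000 steps get very close. This is worth keeping in mind when comparing the exported model's predictions with the exact equation.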
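Copy the version-1 model generated here (the `model/1` directory) directly into the Java project's resources directory. Here is the Java dependency configuration: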
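pom.xml:

```xml
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.7.0</version>
</dependency>
```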
TensorflowUtils:

```java
import org.tensorflow.SavedModelBundle;

public class TensorflowUtils {
    // Loads a SavedModel export directory; "serve" matches tf.saved_model.tag_constants.SERVING
    public static SavedModelBundle loadmodel(String modelpath) {
        SavedModelBundle bundle = SavedModelBundle.load(modelpath, "serve");
        return bundle;
    }
}
```
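`SavedModelBundle.load` reads the export directory from the local filesystem, so the copied `1` directory has to exist as a real directory at runtime. A model packed inside a JAR would first need to be extracted to disk. Also, `Model.init` builds the path with `getResource("/").getPath()`, which returns the URL-encoded path (a space becomes `%20`), so loading will typically fail if the project sits in a directory whose name contains spaces. Below is a minimal sketch of a more robust lookup. It assumes the model directory is named `1` at the classpath root, and `ModelPaths`, `modelDir` and `toFilesystemPath` are hypothetical names, not part of TensorFlow.

```java
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;

public class ModelPaths {
    // Hypothetical helper: turn a file: URI into a decoded filesystem path string
    // (e.g. "%20" becomes a space) that can be handed to SavedModelBundle.load.
    static String toFilesystemPath(URI uri) {
        if (!"file".equals(uri.getScheme())) {
            // e.g. a jar: URI when the model is packed inside a JAR; extract it to disk first
            throw new IllegalArgumentException("Model is not a plain directory on disk: " + uri);
        }
        return Paths.get(uri).toString();
    }

    // Hypothetical helper: locate an exported version directory (e.g. "1") at the classpath root.
    static String modelDir(Class<?> anchor, String version) {
        URL url = anchor.getResource("/" + version);
        if (url == null) {
            throw new IllegalStateException("Model directory not found on classpath: /" + version);
        }
        try {
            return toFilesystemPath(url.toURI());
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Malformed resource URL: " + url, e);
        }
    }

    public static void main(String[] args) {
        URI uri = URI.create("file:/tmp/my%20models/1");
        System.out.println(uri.getRawPath());      // still percent-encoded, like URL.getPath()
        System.out.println(toFilesystemPath(uri)); // decoded path (on Linux/macOS)
    }
}
```

With this helper, `init` could call `TensorflowUtils.loadmodel(ModelPaths.modelDir(getClass(), "1"))` instead of concatenating the raw classpath string.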
Model.java (contains `main`). Note that `feed` and `fetch` take graph operation names, so they use `x` and `y` (set with `name=` in the Python script), not the signature keys `input`/`output`:

```java
import com.xxxxxx.util.TensorflowUtils;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

public class Model {
    SavedModelBundle bundle = null;

    public void init() {
        // The version-1 export was copied into resources, so it sits at <classpath root>/1
        String classpath = this.getClass().getResource("/").getPath() + "1";
        bundle = TensorflowUtils.loadmodel(classpath);
    }

    public double getResult(float[][] arr) {
        // Tensors hold native memory, so close both input and output (try-with-resources)
        try (Tensor<?> input = Tensor.create(arr);
             Tensor<?> result = bundle.session().runner().feed("x", input).fetch("y").run().get(0)) {
            // Output shape is [batch, 1]; a single input row gives [1][1]
            float[][] resultValues = result.copyTo(new float[1][1]);
            return resultValues[0][0];
        }
    }

    public static void main(String[] args) {
        Model model = new Model();
        model.init();
        float[][] arr = new float[1][3];
        arr[0][0] = 1f;
        arr[0][1] = 0.5f;
        arr[0][2] = 2.0f;
        System.out.println(model.getResult(arr));
        model.bundle.close();
    }
}
```
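The value `main` prints can be checked by hand. The training data were generated from y = x1·1 + x2·2 + x3·3 plus small noise, so the input row {1, 0.5, 2} has a noise-free target of 1 + 1 + 6 = 8. A prediction slightly below 8 is plausibly explained by the 302 gradient-descent steps at learning rate 0.005 not fully converging. The hypothetical class below just redoes that arithmetic and compares it with the value from the original run.

```java
public class ExpectedOutput {
    // Coefficients of the equation the Python script used to generate train_y
    static final double[] TRUE_WEIGHTS = {1.0, 2.0, 3.0};

    // Noise-free target for one input row: x1 * 1 + x2 * 2 + x3 * 3
    static double target(float[] row) {
        double sum = 0.0;
        for (int i = 0; i < row.length; i++) {
            sum += row[i] * TRUE_WEIGHTS[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] row = {1f, 0.5f, 2.0f};         // the same input as Model.main
        double predicted = 7.721486568450928;   // value printed by Model.main in the original run
        double expected = target(row);
        System.out.println("expected = " + expected); // expected = 8.0
        System.out.printf("relative error = %.1f%%%n", 100 * Math.abs(predicted - expected) / expected);
    }
}
```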
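Running `Model.main` prints:

7.721486568450928

In short, one model exported during training can be published on TensorFlow Serving as-is and can also be loaded directly from Java, which makes this a very practical approach.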