tensorflow训练好的模型中java调用

最近基于bi-lstm做了一个辱骂识别模型准备部署到线上,之前打算用python 启动一个service 通过http请求来调用,发现公司平台是基于rpc服务的,开发部署起来也较蛋疼,今天下午闲来没事,看到tensorflow中有提供官方例子,通过python中训练好模型,用java来调用,刚刚好摸索了下,动手写了下代码,总算能在java中调用,废话不多说,直接看代码实现情况。


tensorflow版本情况:

 
    
  1. In [1]: import tensorflow as tf
  2. In [2]: tf.__version__
  3. Out[2]: '1.2.1'

java需要1.8的版本



maven依赖:

 
    
  1. org.tensorflowgroupId>
  2. tensorflowartifactId>
  3. 1.2.1version>
  4. dependency>


参考资料:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
http://blog.csdn.net/lyg5623/article/details/72781405


tensorflow训练模型时候要保存的模型参数,主要有是三个,一个是模型输入的tensor大小,一个是dropout参数,一个是模型预测的logits(score/pred_y 表示name_scope下的pred_y)值,也就是y;模型保存为一个二进制文件,可以在java中加载:

 
    
  1. if i%500==0 and i>0:
  2. graph = tf.graph_util.convert_variables_to_constants(session, session.graph_def,
  3. ["keep_prob", "input_x", "score/pred_y"])
  4. tf.train.write_graph(graph, ".", "/Users/shuubiasahi/Desktop/tensorflow/modelsavegraph/graph.db",
  5. as_text=False)


java代码如下,其中gettexttoid方法参考tensorflow中 tensorflow.contrib.keras.preprocessing.sequence.pad_sequences下的实现,用于做文本预测:

 
    
  1. package com.meituan.test;
  2. import java.io.BufferedReader;
  3. import java.io.File;
  4. import java.io.FileInputStream;
  5. import java.io.IOException;
  6. import java.io.InputStreamReader;
  7. import java.nio.ByteBuffer;
  8. import java.nio.ByteOrder;
  9. import java.nio.IntBuffer;
  10. import java.nio.file.Files;
  11. import java.nio.file.Paths;
  12. import java.nio.file.Path;
  13. import java.util.ArrayList;
  14. import java.util.Arrays;
  15. import java.util.Collection;
  16. import java.util.HashMap;
  17. import java.util.List;
  18. import java.util.Map;
  19. import org.apache.commons.io.FileUtils;
  20. import org.apache.commons.lang.StringUtils;
  21. import org.tensorflow.Graph;
  22. import org.tensorflow.Session;
  23. import org.tensorflow.Tensor;
  24. public class TensorflowEx {
  25. private static String path = "/Users/shuubiasahi/Documents/python/credit-tftextclassify-abuse/vocab_cnews.txt";
  26. private static Map<String, Integer> word_to_id = new HashMap<String, Integer>();
  27. static {
  28. try {
  29. BufferedReader buffer = null;
  30. buffer = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
  31. int i=0;
  32. String line=buffer.readLine().trim();
  33. while(line!=null){
  34. word_to_id.put(line, i++);
  35. line=buffer.readLine().trim();
  36. }
  37. buffer.close();
  38. } catch (Exception e) {
  39. }
  40. System.out.println("word_to_id.size is:"+word_to_id.size());
  41. }
  42. public static void main(String[] args) {
  43. byte[] graphDef = readAllBytesOrExit(Paths.get(
  44. "/Users/shuubiasahi/Desktop/tensorflow/modelsavegraph",
  45. "graph.db"));
  46. Graph g = new Graph();
  47. g.importGraphDef(graphDef);
  48. Session sess = new Session(g);
  49. String text="艹你麻痹的垃圾店家,劳资点的香干回锅肉套餐,你他麻痹炒个香干炒肉过来凑数,套餐内所有的东西都没看到,还尼玛口口声声说退款?退你麻痹,留着给你家人买棺材用吧,狗日的东西!";
  50. int[][] arr=gettexttoid(text);
  51. Tensor input = Tensor.create(arr);
  52. Tensor x = Tensor.create(1.0f);
  53. Tensor result = sess.runner().feed("input_x", input).feed("keep_prob", x)
  54. .fetch("score/pred_y").run().get(0);
  55. long[] rshape = result.shape();
  56. /*
  57. * 模型为一个二分类模型,故nlabels=2,模型预测一条数据故batchsize=1
  58. * 预测出来是一个1*2的数组,一条数据有两个概率
  59. *
  60. * */
  61. int nlabels = (int) rshape[1];
  62. int batchSize = (int) rshape[0];
  63. float[][] logits = result.copyTo(new float[batchSize][nlabels]);
  64. System.out.println("辱骂模型识别的概率为:"+logits[0][1]);
  65. System.out.println("sucess");
  66. }
  67. private static byte[] readAllBytesOrExit(Path path) {
  68. try {
  69. return Files.readAllBytes(path);
  70. } catch (IOException e) {
  71. System.err.println("Failed to read [" + path + "]: "
  72. + e.getMessage());
  73. System.exit(1);
  74. }
  75. return null;
  76. }
  77. /*
  78. * 序列默人长度为300
  79. * */
  80. public static int[][] gettexttoid(String text){
  81. int[][] xpad = new int[1][300];
  82. if(StringUtils.isBlank(text)){
  83. return xpad;
  84. }
  85. char[] chs=text.trim().toLowerCase().toCharArray();
  86. List<Integer> list=new ArrayList<Integer>();
  87. for(int i=0;i<chs.length;i++){
  88. String element=Character.toString(chs[i]);
  89. if(word_to_id.containsKey(element)){
  90. list.add(word_to_id.get(element));
  91. }
  92. }
  93. if(list.size()==0){
  94. return xpad;
  95. }
  96. int size = list.size();
  97. Integer[] targetInter= (Integer[]) list.toArray(new Integer[size]);
  98. int[] target= Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
  99. if(size<=300){
  100. System.arraycopy(target, 0, xpad[0], xpad[0].length-size, target.length);
  101. }else{
  102. System.arraycopy(target, size-xpad[0].length, xpad[0], 0, xpad[0].length);
  103. }
  104. return xpad;
  105. }
  106. /*
  107. * 自定义长度
  108. * */
  109. public static int[][] gettexttoid(String text,int maxlen){
  110. if(maxlen<1){
  111. throw new IllegalArgumentException("maxlen长度必须大于等于1");
  112. }
  113. int[][] xpad = new int[1][maxlen];
  114. if(StringUtils.isBlank(text)){
  115. return xpad;
  116. }
  117. char[] chs=text.trim().toLowerCase().toCharArray();
  118. List<Integer> list=new ArrayList<Integer>();
  119. for(int i=0;i<chs.length;i++){
  120. String element=Character.toString(chs[i]);
  121. if(word_to_id.containsKey(element)){
  122. list.add(word_to_id.get(element));
  123. }
  124. }
  125. if(list.size()==0){
  126. return xpad;
  127. }
  128. int size = list.size();
  129. Integer[] targetInter= (Integer[]) list.toArray(new Integer[size]);
  130. int[] target= Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
  131. if(size<=maxlen){
  132. System.arraycopy(target, 0, xpad[0], xpad[0].length-size, target.length);
  133. }else{
  134. System.arraycopy(target, size-xpad[0].length, xpad[0], 0, xpad[0].length);
  135. }
  136. return xpad;
  137. }
  138. }

结果对比:

java结果:


tensorflow训练好的模型中java调用_第1张图片



python启动的service结果:

tensorflow训练好的模型中java调用_第2张图片







结果一致,下周计划写个java service项目,把模型部署上线。


不过我碰到过问题,在java中做预测,1秒最多只能预测十来条文本,这感觉太慢了,不知道什么原因,我机器用的cpu,不知道是否要用gpu做预测,有知道的告诉我
联系我  xuxu_ge

你可能感兴趣的:(java,机器学习)