Java调用在python中完成训练的tensorflow模型

众所周知,tensorflow是Google开源的一个深度学习框架。虽然它除了python语言兼容,还提供了c和Java的API,但目前大部分人还是选择在python环境进行深度学习模型的搭建和训练,不过很多时候,我们需要将完成训练之后较为有效的模型进行封装和部署,但是像许多公司最后的部署都是用Java语言实现的,所以,此时我们就需要掌握如何通过Java来调用tensorflow的模型。

保存模型

首先,我们需要在python中将模型导出为pb文件(二进制文件)

#coding=utf-8
import tensorflow as tf
 
 
# 定义图
x = tf.placeholder(tf.float32, name="x")
y = tf.get_variable("y", initializer=10.0)
z = tf.log(x + y, name="z")
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
 
    # 进行一些训练代码,此处省略
    # xxxxxxxxxxxx
 
    # 显示图中的节点
    print([n.name for n in sess.graph.as_graph_def().node])
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names=["z"])  # 模型的输出变量名称
 
    # 保存图为pb文件
    with open('model.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

配置maven依赖

在Java中,我们需要用到以下依赖包:


    org.tensorflow
      libtensorflow_jni
      1.12.0 
	  
	  
	      org.tensorflow
	      tensorflow
	      1.12.0
	  
	  
	      org.tensorflow
	      libtensorflow
	      1.12.0
	  
	  
	      commons-io
	      commons-io
	      2.6


Java调用tensorflow

  1. 加载pb文件,生成计算图;
  2. 将数据放入Buffer,然后生成Tensor变量(其实就是tensorflow中的常量,用于喂给placeholder),或者直接用数组也可以;
  3. 将Tensor变量喂给placeholder节点,然后在session中运行计算得到我们想要的输出节点(即python的sess.run);
  4. 将输出的Tensor变量转化为Java能够进一步计算的数组类型。
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.Serializable;
import java.nio.IntBuffer;
import java.util.LinkedList;
import java.util.List;

public class Model{
    Graph graph = new Graph();
    Session session = null;
    private String model_path = "";

    public NERLstm(String model_path) throws FileNotFoundException, IOException {
        this.model_path = model_path;
        // 将pb文件转化为byte数组
        byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(this.model_path));
        // 将byte数组导入为tensorflow的计算图
        graph.importGraphDef(graphBytes);
        // 创建session
        session = new Session(graph);
    }

    public List<float[][]> Predict(Integer[] seq, int seq_len_list) {

        List<float[][]> result = new LinkedList<float[][]>();

        // 将数据放入Buffer
        Long[] data = new Long[seq_len_list];
        IntBuffer buffer = IntBuffer.allocate(Integer.valueOf(seq_len_list));
        for (int i = 0; i < seq.length; i++) {
            buffer.put(seq[i]);
        }

        buffer.flip();
        // 利用Buffer创建Tensor变量
        long[] shape = { 1, seq_len_list }; //Tensor变量的shape 
        Tensor<Integer> word_ids = Tensor.create(shape, buffer);
        long[] shape_seq = { 1 };
        IntBuffer buffer_seq = IntBuffer.allocate(Integer.valueOf(1));
        int seq_len = Integer.valueOf(seq_len_list);
        buffer_seq.put(seq_len);
        buffer_seq.flip();
        Tensor<Integer> sequence_lengths = Tensor.create(shape_seq, buffer_seq);
        Tensor<Float> drop_out = (Tensor<Float>) Tensor.create(1.0f);

        try {
        	// 根据实际情况,feed需要的placeholder变量,fetch输出变量
            List<Tensor<?>> res = session.runner().feed("word_ids", word_ids).feed("sequence_lengths", sequence_lengths)
                    .feed("dropout", drop_out).fetch("logits_output").fetch("transitions").run();
            // 将Tensor变量转化为数组
            float[][][] logits = res.get(0).copyTo(
                    new float[(int) res.get(0).shape()[0]][(int) res.get(0).shape()[1]][(int) res.get(0).shape()[2]]);
            float[][] logits2D = logits[0];

            float[][] transparams = res.get(1)
                    .copyTo(new float[(int) res.get(1).shape()[0]][(int) res.get(1).shape()[1]]);
            result.add(logits2D);
            result.add(transparams);
        } catch (Exception e) {
            System.err.println(e.getMessage());
        }
        return result;
    }
}

欢迎关注同名公众号:“我就算饿死也不做程序员”。
交个朋友,一起交流,一起学习,一起进步。Java调用在python中完成训练的tensorflow模型_第1张图片

你可能感兴趣的:(tensorflow,Java,python,java,tensorflow,深度学习,python)