Tensorflow之Java部署方案

最近使用Tensorflow的Estimator高阶API进行模型训练,支持保存成checkpoint和saved model格式。
其中saved model可以使用Tensorflow Serving进行部署.

但是目前公司内部还是使用tensorflow的Java API进行部署,所以需要把saved model转化成pb格式,
并用Java进行部署。希望后续能够转移到Tensorflow Serving方式部署,毕竟这是官方推荐的部署方式,而且无论从扩展性,
易用性都能得到保证,而Java API直接就告诉你这是Experimental的了, 还不是稳定版。

saved_model转pb

由于从estimator保存的格式为saved model,所以第一步是需要转成pd格式.
需要注意的是estimator在export_saved_model时,需要指定好对应的placeholder,确保后续有对应的feed的入口。

  1. 使用tf.saved_model.loader.load加载saved model格式,其中需要指定tags,而estimator保存saved model时默认tag为’serve’
  2. 根据output node来freeze graph。如果我们用了feature column的话,还需要把init_all_tables也加入到output node,该问题在github已经被讨论过, 具体见下面的参考资料。
  3. 把freeze之后的graph和constants参数序列化为pb文件
# estimator export feed node
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
  "index": tf.placeholder(tf.int32, [None, field_size], name='index'),
  "value": tf.placeholder(tf.float32, [None, field_size], name='value'),
})

estimator.export_saved_model(saved_model_dir, input_fn)
def save_model2pb(save_model_dir, pb_path, output_node, tags=['serve']):
  with tf.Session(graph=tf.Graph()) as sess, tf.device("/cpu:0"):
    tf.saved_model.loader.load(sess, tags, save_model_dir)
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess=sess,
        input_graph_def=sess.graph_def,
        output_node_names=output_node)

    with tf.gfile.GFile(pb_path, "wb") as f:
      f.write(output_graph_def.SerializeToString())

    print("%d ops in the final graph." % len(output_graph_def.node))

Java部署

java的部署方式与python类似,对于下面代码,我们的input是value和index,output是output_node

  1. 先创建graph和session,以byte的方式读取pb文件
  2. 构建用于feed的tensor. 其中用于feed的node需要在estimator保存成saved model时export出来
  3. 使用runner进行predict,取出predict结果。
public class DeepModel {

    private Graph graph;
    private Session sess;

    public DeepModel(String pbFile) {
        try {
            graph = new Graph();
            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(pbFile));
            graph.importGraphDef(graphBytes);
            sess = new Session(graph);
        } catch (java.io.IOException e) {
            System.out.println("DeepModel initial fail!!!");
        }
    }

    public boolean isNull() {
        return (sess == null) || (graph == null);
    }

    public float[][] predict(int[][] index, float[][] value) {
        Tensor indexTensor = Tensor.create(index);
        Tensor valueTensor = Tensor.create(value);

        Tensor rlt = sess.runner().feed("index", indexTensor).feed("value", valueTensor).fetch("output_node").run().get(0);
        float[][] finalRlt = new float[index.length][1];
        rlt.copyTo(finalRlt);
        return finalRlt;
    }
}

推荐一下自己的项目:search-deeplearning
参考资料:

  1. Using the SavedModel format
  2. freeze_graph not initializing tables

你可能感兴趣的:(DeepLearning)