java调用keras离线训练的图片识别模型进行在线预测

目前深度学习主要使用Python训练自己的模型,其中Keras提供了heigh-level语法,后端可采用Tensorflow或者Theano。

 

但是在实际应用时,大多数公司仍是使用java作为应用系统后台。于是便有了Python离线训练模型,Java调用模型实现在线预测。

 

Java调用Keras模型有两种方案,一种是基于Java的深度学习库DL4J导入Keras模型,另外一种是利用Tensorflow的java接口调用。DL4J目前暂不支持嵌套模型的导入,下面仅介绍第二种方案。

 

想要利用Tensorflow的java接口调用Keras模型,就需要将Keras保存的模型文件(.h5)转换为Tensorflow的模型文件(.pb)。

 

GitHub 已经有大神写了一个转换工具可以很方便的将Keras模型转Tensorflow模型,你只需要输入原模型文件的位置 和 目标模型文件的位置即可。

Keras的模型可以通过model.save() 方法保存为一个单独的模型文件(model.h5),此模型文件包含网络结构和权重参数。

此种模型可通过以下代码将Keras模型转换为Tensorflow模型:

python keras_to_tensorflow.py 
    --input_model="path/to/keras/model.h5" 
    --output_model="path/to/save/model.pb"

Keras的模型也可以通过 model.to_json() model.save_weight()  来分开保存模型的结构(model.json)和 权重参数(weights.h5)。

此种模型可通过以下代码将Keras模型转换为Tensorflow模型:

python keras_to_tensorflow.py 
    --input_model="path/to/keras/model.h5" 
    --input_model_json="path/to/keras/model.json" 
    --output_model="path/to/save/model.pb"

转换之后便可在Java中愉快的对model.pb进行调用了。

java端的调用代码可使用官方提供的LabelImage,用起来很舒服,只需要替换代码里需要指定模型输入和输出的名称即可。

部分代码如下:

private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {// 调用模型graphDef
																						// 并输入图片image
																						// 进行预测
		long invokeGraphPre = System.currentTimeMillis();																				// ,得到结果
		try (Graph g = new Graph()) {
			g.importGraphDef(graphDef);
		long invokeGraphAft = System.currentTimeMillis();	
		logger.debug("加载模型用时:"+(invokeGraphAft - invokeGraphPre));
			try (Session s = new Session(g);
					// Generally, there may be multiple output tensors, all of
					// them must be closed to prevent resource leaks.
					
					Tensor result = s.runner().feed("inception_v3_input", image).fetch("dense_3/Softmax").run()
							.get(0).expect(Float.class)) {
				final long[] rshape = result.shape();
				if (result.numDimensions() != 2 || rshape[0] != 1) {
					throw new RuntimeException(String.format(
							"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
							Arrays.toString(rshape)));
				}
				int nlabels = (int) rshape[1];
				return result.copyTo(new float[1][nlabels])[0];
			}
		}
	}

 

 

遇到的问题:

Java 在调用转换之后的model.pb 的时候,预测结果混乱,本来是五分类的模型,实际预测结果绝大部分仅在两个分类之间跳动,而且和标签基本不一致。起初,找不到原因,在GitHub上看到有相同问题,起初认为是转换工具的原因,就尝试用DL4J去直接调用Keras模型,各种报错,最后发现是因为同事给我Keras模型是嵌套的,DL4J暂不支持嵌套模型的调用。于是回头继续研究之前的做法,偶然间发现有人在转换之后预测结果混乱的原因是图片格式问题,灵机一动,去查看LabelImage示例代码中:

 private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);
      // Some constants specific to the pre-trained model at:
      // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
      //
      // - The model was trained with images scaled to 224x224 pixels.
      // - The colors, represented as R, G, B in 1-byte each were converted to
      //   float using (value - Mean)/Scale.
      final int H = 224;
      final int W = 224;
      final float mean = 117f;
      final float scale = 1f;

      // Since the graph is being constructed once per execution here, we can use a constant for the
      // input image. If the graph were to be re-used for multiple input images, a placeholder would
      // have been more appropriate.
      final Output input = b.constant("input", imageBytes);
      final Output output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), Float.class),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      try (Session s = new Session(g)) {
        // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
        return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
      }
    }
  }

mean和scale的设置为117f、1f。

将两个参数全部改为128f之后,预测结果恢复正常!完全匹配。

work done!

你可能感兴趣的:(java调用keras离线训练的图片识别模型进行在线预测)