Faster RCNN 推理 从头写 java (二) RPN网络预测

目录:

  • 1. 图片预处理
  • 2. RPN网络预测
  • 3. RPN to ROIs
  • 4. Classifier 网络预测
  • 5. Classifier网络输出对 ROIs过滤与修正
  • 6. NMS (非最大值抑制)
  • 7. 坐标转换为原始图片维度

一: 输入输出

输入:

  • omg: 经过预处理过的图像, shape为 [1, 600, 800, 3].

输出:

  • cls: 每个anchor在pixel上的概率, shape为 [1, 37, 50, 49].
  • reg: 每个anchor在pixel上的回归值, shape 为 [1, 37, 50, 196].
  • feature: 经过VGG16后的feature map, shape 为 [1, 37, 50, 512].

二: 流程

  • 图片BGR 格式转换为 RGB 格式。
  • 图片缩放。
  • 图片均值中值化。

三: code by code

img 转换为tensorflow 的 Tensor

Tensor input = TypeConvertor.ndarrayToTensor(img);

预测

List> output = this.session.runner().
        feed(INPUT_NAME, input).
        fetch(OUTPUT_CLS_NAME).fetch(OUTPUT_REG_NAME).fetch(OUTPUT_FEATURE_MAP_NAME).
        run();

构建输出
0: cls
1: reg
3: feature

return new FasterRCnnRPN_Output(output.get(0), output.get(1), output.get(2));

你可能感兴趣的:(Faster RCNN 推理 从头写 java (二) RPN网络预测)