Faster RCNN 推理 从头写 java (四) Classifier 网络预测

目录:

  • 1. 图片预处理
  • 2. RPN网络预测
  • 3. RPN to ROIs
  • 4. Classifier 网络预测
  • 5. Classifier网络输出对 ROIs过滤与修正
  • 6. NMS (非最大值抑制)
  • 7. 坐标转换为原始图片维度

一: 输入输出

输入:

  • ROIs: RPN to ROI 后 没32个为一组的ROIs, shape为 [1, 32, 4]
  • feature: RPN 层的输出, 也就是VGG16的feature map, shape 为 [1, 37, 50, 512]

输出:

  • P_cls: 每个ROI的概率 shape为 [1, 32, 2]
  • P_regr: 每个ROI的回归值, shape 为 [1, 37, 50, 4]

二: 流程

  • 预测

三: code by code

ROIs, feature 转换为tensorflow 的 Tensor

if (featureMap.dataType() != DataType.FLOAT) featureMap = featureMap.castTo(DataType.FLOAT);
Tensor feature_input = TypeConvertor.ndarrayToTensor(featureMap);

if (ROIs.dataType() != DataType.FLOAT) ROIs = ROIs.castTo(DataType.FLOAT);
Tensor ROIs_input = TypeConvertor.ndarrayToTensor(ROIs);

Classifier 网络模型预测

List> output = this.session.runner().
        feed(INPUT_FEATURE_NAME, feature_input).feed(INPUT_ROI_NAME, ROIs_input).
        fetch(OUTPUT_CLS_NAME).fetch(OUTPUT_REG_NAME).
        run();

构建输出
0: P_cls
1: P_regr

return new FasterRCnnClassifier_Output(output.get(0), output.get(1));

你可能感兴趣的:(Faster RCNN 推理 从头写 java (四) Classifier 网络预测)