Deteron2 Faster-RCNN 代码阅读笔记

Deteron2 Faster-RCNN 代码阅读笔记

整体结构

在 Detectron2 中 Faster/Mask RCNN 系列是通过 GeneralizedRCNN 来实现的,代码位于 detectron2\modeling\meta_arch\rcnn.py 。类的关系入下图所示
Deteron2 Faster-RCNN 代码阅读笔记_第1张图片

GeneralizedRCNN 由三部分组成 : backbone、proposal_generator 和 roi_heads,分别通过 build_backbone,build_proposal_generator, build_roi_heads 构建。初始化的部分代码如下。目前 Detectron2 支持的 Backbone 有 Resnet50(+FPN), Rest101(+FPN), proposal_generatordetectron2\modeling\proposal_generator\rpn.py 中的 RPN 提供, roi_headsdetectron2\modeling\roi_heads\roi_heads.py 中的 StandardROIHeads 提供。

class GeneralizedRCNN(nn.Module):
    """
    Generalized R-CNN. Any models that contains the following three components:
    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction
    """

    def __init__(self, cfg):
        super().__init__()

        self.device = torch.device(cfg.MODEL.DEVICE)
        self.backbone = build_backbone(cfg)
        self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape())
        self.roi_heads = build_roi_heads(cfg, self.backbone.output_shape())
        self.vis_period = cfg.VIS_PERIOD
        self.input_format = cfg.INPUT.FORMAT

GeneralizedRCNN 的前向传播过程为过程为

preprocess_image
backone feature
proposal_generator
roi_heads

输入参数为 bachted_inputs,类型是字典,字典中包括 image, instances, proposals, height, width。返回值也是一个字典,包括 pred_boxes, pred_classes, scores, pred_masks, pred_keypoints.

你可能感兴趣的:(Detectron2,深度学习)