mmdetection中的anchor部分的代码理解和注释

/home/wuchenxi/mmdetection/mmdet/models/anchor_heads/anchor_head.py

from __future__ import division

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, multiclass_nms, force_fp32)
from ..builder import build_loss
from ..registry import HEADS

@HEADS.register_module
class AnchorHead(nn.Module):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).

Args:
    num_classes (int): Number of classes; when sigmoid classification is
        used, one background slot is dropped from the output channels
        (see ``cls_out_channels`` in ``__init__``).
    in_channels (int): Number of channels in the input feature map.
    feat_channels (int): Number of channels of the feature map.
    anchor_scales (Iterable): Anchor scales.
    anchor_ratios (Iterable): Anchor aspect ratios.
    anchor_strides (Iterable): Anchor strides, one per feature level;
        they also serve as the default anchor base sizes.
    anchor_base_sizes (Iterable): Anchor base sizes.
    target_means (Iterable): Mean values of regression targets.
    target_stds (Iterable): Std values of regression targets.
    loss_cls (dict): Config of classification loss.
    loss_bbox (dict): Config of localization loss.
"""  # noqa: W605

def __init__(self,
             num_classes,
             in_channels,
             feat_channels=256,
             anchor_scales=(8, 16, 32),
             anchor_ratios=(0.5, 1.0, 2.0),  # aspect ratios 1:2, 1:1, 2:1
             anchor_strides=(4, 8, 16, 32, 64),  # one per FPN level; also the default base sizes
             anchor_base_sizes=None,
             target_means=(.0, .0, .0, .0),
             target_stds=(1.0, 1.0, 1.0, 1.0),
             loss_cls=dict(
                 type='CrossEntropyLoss',
                 use_sigmoid=True,
                 loss_weight=1.0),
             loss_bbox=dict(
                 type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)):
    """Initialize the anchor head.

    Sequence defaults are tuples (not lists) so no shared mutable
    default is exposed across instances; callers may still pass lists.
    NOTE: the dict defaults are shared objects — they are only read by
    build_loss here, never mutated.
    """
    super(AnchorHead, self).__init__()
    self.in_channels = in_channels
    self.num_classes = num_classes
    self.feat_channels = feat_channels
    self.anchor_scales = anchor_scales
    self.anchor_ratios = anchor_ratios
    self.anchor_strides = anchor_strides
    # Default: each level's base anchor size equals its stride.
    self.anchor_base_sizes = list(
        anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
    self.target_means = target_means
    self.target_stds = target_stds

    self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
    # FocalLoss/GHMC weight all anchors themselves, so no pos/neg sampling.
    self.sampling = loss_cls['type'] not in ['FocalLoss', 'GHMC']
    if self.use_sigmoid_cls:
        # Sigmoid scoring drops the explicit background channel.
        self.cls_out_channels = num_classes - 1
    else:
        self.cls_out_channels = num_classes
    self.loss_cls = build_loss(loss_cls)
    self.loss_bbox = build_loss(loss_bbox)
    self.fp16_enabled = False

    # One AnchorGenerator per feature level, sharing scales and ratios.
    self.anchor_generators = []
    for anchor_base in self.anchor_base_sizes:
        self.anchor_generators.append(
            AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
    self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
    self._init_layers()

def _init_layers(self):
    """Build the 1x1 classification and regression prediction convs."""
    cls_channels = self.num_anchors * self.cls_out_channels
    reg_channels = self.num_anchors * 4
    self.conv_cls = nn.Conv2d(self.feat_channels, cls_channels, 1)
    self.conv_reg = nn.Conv2d(self.feat_channels, reg_channels, 1)

def init_weights(self):
    """Initialize both prediction convs with a std=0.01 Gaussian."""
    for conv in (self.conv_cls, self.conv_reg):
        normal_init(conv, std=0.01)

def forward_single(self, x):
    """Predict scores and box deltas for one feature level.

    Returns:
        tuple: (cls_score, bbox_pred) from the two 1x1 convs.
    """
    return self.conv_cls(x), self.conv_reg(x)

def forward(self, feats):
    # Map forward_single over all feature levels; multi_apply (from
    # mmdet.core) collects the per-level outputs — presumably into one
    # list per output, i.e. (all cls_scores, all bbox_preds).
    return multi_apply(self.forward_single, feats)

def get_anchors(self, featmap_sizes, img_metas):
    """Compute anchors and validity flags for a batch of images.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.
        img_metas (list[dict]): Image meta info.

    Returns:
        tuple: anchors of each image, valid flags of each image
    """
    num_levels = len(featmap_sizes)

    # Feature-map sizes are shared by the whole batch, so the anchors
    # are generated once and the same per-level list is referenced for
    # every image.
    multi_level_anchors = [
        self.anchor_generators[lvl].grid_anchors(featmap_sizes[lvl],
                                                 self.anchor_strides[lvl])
        for lvl in range(num_levels)
    ]
    anchor_list = [multi_level_anchors for _ in img_metas]

    # Valid flags depend on each image's padded shape, so they are
    # recomputed per image.
    valid_flag_list = []
    for meta in img_metas:
        pad_h, pad_w, _ = meta['pad_shape']
        multi_level_flags = []
        for lvl in range(num_levels):
            stride = self.anchor_strides[lvl]
            feat_h, feat_w = featmap_sizes[lvl]
            # Cells beyond the padded image extent are invalid.
            valid_h = min(int(np.ceil(pad_h / stride)), feat_h)
            valid_w = min(int(np.ceil(pad_w / stride)), feat_w)
            multi_level_flags.append(
                self.anchor_generators[lvl].valid_flags(
                    (feat_h, feat_w), (valid_h, valid_w)))
        valid_flag_list.append(multi_level_flags)

    return anchor_list, valid_flag_list

def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                bbox_targets, bbox_weights, num_total_samples, cfg):
    """Compute classification and regression losses for one level.

    All per-anchor tensors are flattened so the configured loss
    modules see (total_anchors, ...) shaped inputs.
    """
    # Classification: (N, A*C, H, W) -> (N*H*W*A, C).
    flat_scores = cls_score.permute(0, 2, 3,
                                    1).reshape(-1, self.cls_out_channels)
    loss_cls = self.loss_cls(
        flat_scores,
        labels.reshape(-1),
        label_weights.reshape(-1),
        avg_factor=num_total_samples)
    # Regression: (N, A*4, H, W) -> (N*H*W*A, 4).
    flat_preds = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
    loss_bbox = self.loss_bbox(
        flat_preds,
        bbox_targets.reshape(-1, 4),
        bbox_weights.reshape(-1, 4),
        avg_factor=num_total_samples)
    return loss_cls, loss_bbox

@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def loss(self,
         cls_scores,
         bbox_preds,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    """Compute losses from multi-level predictions and ground truth.

    Returns:
        dict: ``loss_cls``/``loss_bbox`` lists (one entry per level),
        or None if target assignment produced nothing.
    """
    # One (h, w) per level; must match the configured anchor levels.
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    assert len(featmap_sizes) == len(self.anchor_generators)

    anchor_list, valid_flag_list = self.get_anchors(
        featmap_sizes, img_metas)
    label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
    # Assign anchors to GT boxes and build per-anchor training targets
    # (external mmdet.core.anchor_target; argument order matters).
    cls_reg_targets = anchor_target(
        anchor_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        self.target_means,
        self.target_stds,
        cfg,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=label_channels,
        sampling=self.sampling)
    if cls_reg_targets is None:
        return None
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = cls_reg_targets
    # With sampling, both positives and negatives normalize the loss;
    # sampling-free losses (FocalLoss/GHMC) average over positives only.
    num_total_samples = (
        num_total_pos + num_total_neg if self.sampling else num_total_pos)
    losses_cls, losses_bbox = multi_apply(
        self.loss_single,
        cls_scores,
        bbox_preds,
        labels_list,
        label_weights_list,
        bbox_targets_list,
        bbox_weights_list,
        num_total_samples=num_total_samples,
        cfg=cfg)
    return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
               rescale=False):
    """Decode batched multi-level predictions into per-image results."""
    assert len(cls_scores) == len(bbox_preds)
    num_levels = len(cls_scores)

    # Anchors depend only on feature-map sizes, so generate them once
    # for the whole batch.
    mlvl_anchors = []
    for lvl in range(num_levels):
        featmap_size = cls_scores[lvl].size()[-2:]
        mlvl_anchors.append(self.anchor_generators[lvl].grid_anchors(
            featmap_size, self.anchor_strides[lvl]))

    result_list = []
    for img_id, img_meta in enumerate(img_metas):
        # Detach so decoding never participates in autograd.
        cls_score_list = [score[img_id].detach() for score in cls_scores]
        bbox_pred_list = [pred[img_id].detach() for pred in bbox_preds]
        proposals = self.get_bboxes_single(
            cls_score_list, bbox_pred_list, mlvl_anchors,
            img_meta['img_shape'], img_meta['scale_factor'], cfg, rescale)
        result_list.append(proposals)
    return result_list

def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      mlvl_anchors,
                      img_shape,
                      scale_factor,
                      cfg,
                      rescale=False):
    """Decode one image's multi-level predictions into detections.

    Args:
        cls_scores (list[Tensor]): Per-level class scores, (A*C, H, W).
        bbox_preds (list[Tensor]): Per-level box deltas, (A*4, H, W).
        mlvl_anchors (list[Tensor]): Per-level anchors, (H*W*A, 4).
        img_shape (tuple): Shape passed to delta2bbox for box clipping.
        scale_factor: Test-time resize factor, undone when ``rescale``.
        cfg: Test config (``nms_pre``, ``score_thr``, ``nms``,
            ``max_per_img``).
        rescale (bool): If True, map boxes back to the original scale.

    Returns:
        tuple: (det_bboxes, det_labels) after multi-class NMS.
    """
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
    mlvl_bboxes = []
    mlvl_scores = []
    for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
                                             mlvl_anchors):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        # (A*C, H, W) -> (H*W*A, C)
        cls_score = cls_score.permute(1, 2,
                                      0).reshape(-1, self.cls_out_channels)
        if self.use_sigmoid_cls:
            scores = cls_score.sigmoid()
        else:
            scores = cls_score.softmax(-1)
        # (A*4, H, W) -> (H*W*A, 4)
        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        # Keep only the nms_pre highest-scoring anchors of this level.
        nms_pre = cfg.get('nms_pre', -1)
        if nms_pre > 0 and scores.shape[0] > nms_pre:
            if self.use_sigmoid_cls:
                max_scores, _ = scores.max(dim=1)
            else:
                # Column 0 is the softmax background class; skip it.
                max_scores, _ = scores[:, 1:].max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            anchors = anchors[topk_inds, :]
            bbox_pred = bbox_pred[topk_inds, :]
            scores = scores[topk_inds, :]
        bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
                            self.target_stds, img_shape)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
    mlvl_bboxes = torch.cat(mlvl_bboxes)
    if rescale:
        mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
    mlvl_scores = torch.cat(mlvl_scores)
    if self.use_sigmoid_cls:
        # Sigmoid scores have no background column; prepend zeros so
        # multiclass_nms sees the softmax-style layout.
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
    det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                            cfg.score_thr, cfg.nms,
                                            cfg.max_per_img)
    return det_bboxes, det_labels

/home/wuchenxi/mmdetection/mmdet/core/anchor/anchor_generator.py

import torch

# Generates the base anchors for one feature level and tiles them over a
# feature-map grid; AnchorHead creates one instance per level.
class AnchorGenerator(object):

def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
    """Store the anchor configuration and precompute the base anchors.

    Args:
        base_size (int): Reference anchor size (AnchorHead passes the
            level's stride by default).
        scales (Iterable): Multiplicative scales applied to base_size.
        ratios (Iterable): Aspect ratios of the anchors.
        scale_major (bool): Enumeration order of the base anchors
            (scales vary fastest when True).
        ctr (tuple | None): Optional explicit anchor center.
    """
    self.scales = torch.Tensor(scales)
    self.ratios = torch.Tensor(ratios)
    self.base_size = base_size
    self.ctr = ctr
    self.scale_major = scale_major
    # Must run last: gen_base_anchors reads all fields set above.
    self.base_anchors = self.gen_base_anchors()

@property
def num_base_anchors(self):
    """int: Number of anchors placed at every grid location."""
    return len(self.base_anchors)

def gen_base_anchors(self):
    w = self.base_size     #在上面的程序上定义的  anchor_strides=[4, 8, 16, 32, 64] 的值传递过来
    h = self.base_size
    if self.ctr is None:
        x_ctr = 0.5 * (w - 1) #anchor的中心点由w,h确定,w-1的目的目前还不确定
        y_ctr = 0.5 * (h - 1)
    else:
        x_ctr, y_ctr = self.ctr

    h_ratios = torch.sqrt(self.ratios)   #根据ratios设定宽高比
    w_ratios = 1 / h_ratios
    if self.scale_major:
        ws = (w * w_ratios[:, None] * self.scales[None, :]).view(-1)  #view(-1)是展平成一维
        hs = (h * h_ratios[:, None] * self.scales[None, :]).view(-1)
    else:
        ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1)
        hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1)

    base_anchors = torch.stack(
        [
            x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
            x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
        ],
        dim=-1).round()

    return base_anchors

def _meshgrid(self, x, y, row_major=True):
    """Pair every element of x with every element of y, flattened."""
    xx = x.repeat(len(y))                          # x varies fastest
    yy = y.view(-1, 1).repeat(1, len(x)).view(-1)  # y constant per row
    return (xx, yy) if row_major else (yy, xx)

def grid_anchors(self, featmap_size, stride=16, device='cuda'):
    """Tile the base anchors over every cell of a feature map.

    Returns:
        Tensor: (feat_h * feat_w * num_base_anchors, 4) anchors; the
        anchors of feature-map cell (0, 0) come first, then (0, 1), ...
    """
    base_anchors = self.base_anchors.to(device)
    feat_h, feat_w = featmap_size

    # One (x, y, x, y) shift per cell, in image coordinates.
    shift_x = torch.arange(0, feat_w, device=device) * stride
    shift_y = torch.arange(0, feat_h, device=device) * stride
    shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
    shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy],
                         dim=-1).type_as(base_anchors)

    # Broadcast A base anchors (1, A, 4) against K shifts (K, 1, 4),
    # then flatten to (K*A, 4).
    all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
    return all_anchors.view(-1, 4)

def valid_flags(self, featmap_size, valid_size, device='cuda'):
    """Flag anchors whose grid cell lies inside the valid area.

    Returns:
        Tensor: uint8 mask of length feat_h * feat_w * num_base_anchors.
    """
    feat_h, feat_w = featmap_size
    valid_h, valid_w = valid_size
    assert valid_h <= feat_h and valid_w <= feat_w

    # 1-D validity masks along each axis, combined over the grid.
    valid_x = torch.zeros(feat_w, dtype=torch.uint8, device=device)
    valid_y = torch.zeros(feat_h, dtype=torch.uint8, device=device)
    valid_x[:valid_w] = 1
    valid_y[:valid_h] = 1
    valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
    cell_valid = valid_xx & valid_yy

    # Every base anchor of a cell inherits that cell's flag.
    return cell_valid[:, None].expand(
        cell_valid.size(0), self.num_base_anchors).contiguous().view(-1)

你可能感兴趣的:(源码笔记)