MMDet逐行代码解读之正负样本采样Sampler

文章目录

  • 前言
  • 1、构造一个简单的sampler
  • 2、BaseSampler类
  • 3、RandomSampler类
    • 3.1 sample方法
    • 3.2 _sample_pos方法
    • 3.2 _sample_neg方法
  • 总结

前言

  本篇是MMdet逐行解读第四篇,代码地址:mmdet/core/bbox/samplers/random_sampler.py。随机采样正负样本主要针对在训练过程中,经过MAXIOUAssigner后,确定出每个anchor和哪个gt匹配后,从这些正负样本中采样来进行loss计算。本文以RPN的config进行讲解,因为该部分用到了随机采样来克服正负样本不平衡;而在RetinaNet中则使用focal loss来克服正负样本不平衡问题,即没有随机采样的过程。
历史文章如下:
 AnchorGenerator解读
 MaxIOUAssigner解读
 DeltaXYWHBBoxCoder解读

1、构造一个简单的sampler

    from mmdet.core.bbox import build_sampler
    # 构造一个sampler
    sampler = dict(
        type='RandomSampler',# 构造一个随机采样器
        num=256,             # 正负样本总数量
        pos_fraction=0.5,    # 正样本比例
        neg_pos_ub=-1,       # 负样本上限
        add_gt_as_proposals=False) # 是否添加gt作为正样本,默认不添加。
    sp = build_sampler(sampler)       
    # 这块不必详细理解,知道意思即可。
    # 就是随机生成一个assigner、bboxes和gt_bboxes,
    from mmdet.core.bbox import AssignResult
    from mmdet.core.bbox.demodata import ensure_rng, random_boxes
    rng = ensure_rng(None)
    assign_result = AssignResult.random(rng=rng)
    bboxes = random_boxes(assign_result.num_preds, rng=rng)
    gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
    gt_labels = None
    # 调用sample方法进行正负样本采样
    self = sp.sample(assign_result, bboxes, gt_bboxes, gt_labels)

2、BaseSampler类

class BaseSampler(metaclass=ABCMeta):
    """Base class of samplers"""

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        self.pos_sampler = self
        self.neg_sampler = self

    @abstractmethod
    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive samples"""
        pass

    @abstractmethod
    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative samples"""
        pass

    def sample(self,
               assign_result,
               bboxes,
               gt_bboxes,
               gt_labels=None,
               **kwargs):
        pass

  基类比较容易理解,核心是sample方法,内部调用_sample_pos方法和_sample_neg方法。后续继承该类的子类只需实现_sample_pos方法和_sample_neg方法即可。

3、RandomSampler类

3.1 sample方法

 以RandomSampler类来讲解代码。首先看下sample方法:

 # 确定正样本个数: 256*0.5 = 128
 num_expected_pos = int(self.num * self.pos_fraction)
 # 调用_sample_pos方法返回采样后正样本的id。
 pos_inds = self.pos_sampler._sample_pos(
     assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
 pos_inds = pos_inds.unique()    # 挑选出tensor独立不重复元素
 num_sampled_pos = pos_inds.numel()  # 确定出正样本个数
 num_expected_neg = self.num - num_sampled_pos #确定负样本个数
 # 由于该参数为-1,故不执行if语句,即实打实的采样254个负样本
 if self.neg_pos_ub >= 0: 
     _pos = max(1, num_sampled_pos)
     #确定负样本的上限是正样本个数的neg_pos_ub倍
     neg_upper_bound = int(self.neg_pos_ub * _pos) 
      # 负样本的个数不能超过上限      
     if num_expected_neg > neg_upper_bound:             
         num_expected_neg = neg_upper_bound              
 # 调用_sample_neg方法返回采样后负样本的id
 neg_inds = self.neg_sampler._sample_neg(
     assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
 neg_inds = neg_inds.unique() # 同理,将id取集合操作。
 # 用SamplingResult进行封装
 sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,  
                                  assign_result, gt_flags)

  代码还是比较容易理解:首先确定正样本个数,然后完成采样;采样后确定负样本的个数,若指定了负样本上限:neg_upper_bound,则负样本个数最多采样不能超过正样本个数的neg_upper_bound倍;若无指定,则负样本个数就是总的数量-正样本个数。

3.2 _sample_pos方法

  再来看下采样正样本的方法:

def _sample_pos(self, assign_result, num_expected, **kwargs):
    """Randomly sample some positive samples."""
    # 找出非0的正样本的id
    pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False) 
    if pos_inds.numel() != 0:   
        pos_inds = pos_inds.squeeze(1)
    # 若正样本个数<期望的128个,则直接返回。
    if pos_inds.numel() <= num_expected: 
        return pos_inds
    # 否则就从pos_inds里面随机采够128个。
    else:
        return self.random_choice(pos_inds, num_expected)

3.2 _sample_neg方法

  和采样正样本方法大同小异,这里看下即可。

def _sample_neg(self, assign_result, num_expected, **kwargs):
    """Randomly sample some negative samples."""
    neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
    if neg_inds.numel() != 0:               
        neg_inds = neg_inds.squeeze(1)
    if len(neg_inds) <= num_expected: # 若负样本数量比期望的还小则直接返回
        return neg_inds
    else:
        return self.random_choice(neg_inds, num_expected)

总结

  下篇将开启model模块介绍,敬请期待。

你可能感兴趣的:(mmcv和mmdet源码注释版,深度学习,python,机器学习)