mmdetection中assign和sample结构的简要解析

前言

最近正在学习mmdetection中assign和sample的相关结构,故写此篇博客做一些简单的介绍。

正文

assign

在mmdet中,官方提供了许多种anchor与bbox的匹配机制,其中包括:max_iou_assigner,atss_assigner等等。但是不论这些匹配细则发生怎样的变化,但是有一条语句却始终没有变化:

return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
# 首先解释一下各参数的含义:
# num_gt:为真实边界框的数量

# assigned_gt_inds:shape=(num_anchors),代表的是一张图片中所铺设的所有anchor
# 对应的满足条件而被分配到的真实边界框,如果没被分配到则为0,
# 其余被分配到的则从1开始标记出对应的gt

# max_overlaps:shape=(num_gts, num_anchors),具体为anchor与gt通过对应分配规则而得到的数值

# labels:shape=(num_anchors),通常有两种形式,当传入参数gt_label为None的时候,
# 背景被分配-1,前景被分配0,
# 如果传入的gt_label不为None的时候前景则按照对应的种类所对应的ind进行分配

而关于AssignResults类,分析起来其实并不麻烦其大体结构主要如下:

class AssignResult(util_mixins.NiceRepr):
	# 初始化
    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
        ...
	
	# 返回anchor的数量
    @property
    def num_preds(self):
        return len(self.gt_inds)
	
	# 用户自定义的属性
    def set_extra_property(self, key, value):
        assert key not in self.info
        self._extra_properties[key] = value

    def get_extra_property(self, key):
        """Get user-defined property."""
        return self._extra_properties.get(key, None)

	# 返回标准信息
    @property
    def info(self):
        ...

	# 规定信息的格式
    def __nice__(self):
        ...

	用于用户自定义的asignner后的debugg
    @classmethod
    def random(cls, **kwargs):
       ...

	# 与sample配合,如果采样的roi中需要将gt也包含进去,则使用下述方法
    def add_gt_(self, gt_labels):
        self_inds = torch.arange(
            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
        self.gt_inds = torch.cat([self_inds, self.gt_inds])

        self.max_overlaps = torch.cat(
            [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])

        if self.labels is not None:
            self.labels = torch.cat([gt_labels, self.labels])

sample

当已经分配完毕之后,分配的结果类,就会被传入到采样类中进行进一步的处理,在这里就需要提到所有采样类的父类base_sampler,采样的核心思想就是收集正负样本,所以分析起来也很简单。

class BaseSampler(metaclass=ABCMeta):
	# 初始化
    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        self.pos_sampler = self
        self.neg_sampler = self

	# 对正样本的采样,在随机采样中如果正样本不足的话,就会将所有的正样本全部保留
    @abstractmethod
    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive samples."""
        pass
	# 负样本采样
    @abstractmethod
    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative samples."""
        pass

    def sample(self,
               assign_result,
               bboxes,
               gt_bboxes,
               gt_labels=None,
               **kwargs):
     	 ...

        sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                                         assign_result, gt_flags)
        return sampling_result

最终返回SamplingResult,其实该类的最终作用就是储存正负样本索引以及anchor和gt_bboxes。

结束

其实要是单纯分析结构的话,还是很好理解的,关键还是需要分析合适的分配规则,待我日后慢慢填坑。如分析有误,欢迎指正!!!!

你可能感兴趣的:(mmdetection,pytorch,python,深度学习)