最近正在学习mmdetection中assign和sample的相关结构,故写此篇博客做一些简单的介绍。
在mmdet中,官方提供了许多种anchor与bbox的匹配机制,其中包括:max_iou_assigner,atss_assigner等等。但是不论这些匹配细则发生怎样的变化,但是有一条语句却始终没有变化:
return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
# 首先解释一下各参数的含义:
# num_gt:为真实边界框的数量
# assigned_gt_inds:shape=(num_anchors),代表的是一张图片中所铺设的所有anchor
# 对应的满足条件而被分配到的真实边界框,如果没被分配到则为0,
# 其余被分配到的则从1开始标记出对应的gt
# max_overlaps:shape=(num_gts, num_anchors),具体为anchor与gt通过对应分配规则而得到的数值
# labels:shape=(num_anchors),通常有两种形式,当传入参数gt_label为None的时候,
# 背景被分配-1,前景被分配0,
# 如果传入的gt_label不为None的时候前景则按照对应的种类所对应的ind进行分配
而关于AssignResults类,分析起来其实并不麻烦其大体结构主要如下:
class AssignResult(util_mixins.NiceRepr):
# 初始化
def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
...
# 返回anchor的数量
@property
def num_preds(self):
return len(self.gt_inds)
# 用户自定义的属性
def set_extra_property(self, key, value):
assert key not in self.info
self._extra_properties[key] = value
def get_extra_property(self, key):
"""Get user-defined property."""
return self._extra_properties.get(key, None)
# 返回标准信息
@property
def info(self):
...
# 规定信息的格式
def __nice__(self):
...
用于用户自定义的asignner后的debugg
@classmethod
def random(cls, **kwargs):
...
# 与sample配合,如果采样的roi中需要将gt也包含进去,则使用下述方法
def add_gt_(self, gt_labels):
self_inds = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
self.gt_inds = torch.cat([self_inds, self.gt_inds])
self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
if self.labels is not None:
self.labels = torch.cat([gt_labels, self.labels])
当已经分配完毕之后,分配的结果类,就会被传入到采样类中进行进一步的处理,在这里就需要提到所有采样类的父类base_sampler,采样的核心思想就是收集正负样本,所以分析起来也很简单。
class BaseSampler(metaclass=ABCMeta):
# 初始化
def __init__(self,
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.pos_sampler = self
self.neg_sampler = self
# 对正样本的采样,在随机采样中如果正样本不足的话,就会将所有的正样本全部保留
@abstractmethod
def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Sample positive samples."""
pass
# 负样本采样
@abstractmethod
def _sample_neg(self, assign_result, num_expected, **kwargs):
"""Sample negative samples."""
pass
def sample(self,
assign_result,
bboxes,
gt_bboxes,
gt_labels=None,
**kwargs):
...
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
return sampling_result
最终返回SamplingResult,其实该类的最终作用就是储存正负样本索引以及anchor和gt_bboxes。
其实要是单纯分析结构的话,还是很好理解的,关键还是需要分析合适的分配规则,待我日后慢慢填坑。如分析有误,欢迎指正!!!!