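In mmdet, anchor_head is the base class for the prediction heads. It provides _init_layers, init_weights, forward_single, forward, get_anchors, loss and get_bboxes. Together these cover both the loss computation used in training and the get_bboxes method used at inference.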
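1. forward: forward_single runs prediction on a single scale and outputs classification scores and bounding-box regression predictions.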
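```python
def forward_single(self, x):
    # x: feature map of one level, shape (N, C, H, W)
    cls_score = self.conv_cls(x)  # (N, num_anchors * cls_out_channels, H, W)
    bbox_pred = self.conv_reg(x)  # (N, num_anchors * 4, H, W)
    return cls_score, bbox_pred
```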
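The sketch below makes the output shapes of forward_single concrete. It is written in NumPy and assumes that conv_cls and conv_reg are 1x1 convolutions. Under that assumption, the head emits one score per (anchor, class) pair and four box deltas per anchor at every feature-map location. The sizes (9 anchors, 80 classes, 256 channels) are hypothetical and chosen only for illustration.

```python
import numpy as np

def conv1x1(x, weight, bias):
    """1x1 convolution on an (N, C_in, H, W) array; weight is (C_out, C_in)."""
    return np.einsum("oc,nchw->nohw", weight, x) + bias[None, :, None, None]

# Hypothetical head sizes, for illustration only.
num_anchors, num_classes, feat_channels = 9, 80, 256
rng = np.random.default_rng(0)

# Stand-ins for conv_cls / conv_reg weights.
w_cls = rng.standard_normal((num_anchors * num_classes, feat_channels))
b_cls = np.zeros(num_anchors * num_classes)
w_reg = rng.standard_normal((num_anchors * 4, feat_channels))
b_reg = np.zeros(num_anchors * 4)

def forward_single(x):
    cls_score = conv1x1(x, w_cls, b_cls)  # one channel per (anchor, class)
    bbox_pred = conv1x1(x, w_reg, b_reg)  # four deltas per anchor
    return cls_score, bbox_pred

x = rng.standard_normal((2, feat_channels, 10, 12))  # one FPN level, (N, C, H, W)
cls_score, bbox_pred = forward_single(x)
print(cls_score.shape)  # (2, 720, 10, 12)
print(bbox_pred.shape)  # (2, 36, 10, 12)
```

Spatial size is preserved. Only the channel dimension changes, which is why the docstring of get_bboxes describes these outputs as (N, num_anchors * num_classes, H, W) and (N, num_anchors * 4, H, W).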
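With an FPN, prediction happens at multiple scales, so the author wrote a helper for handling multiple inputs, multi_apply, and forward simply calls multi_apply(self.forward_single, feats). The core idea is to apply func to each element of the input lists. This yields results of the form [(cls, bbox), (cls, bbox)]. A final zip then regroups predictions of the same kind, giving ([cls1, cls2], [bbox1, bbox2]).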
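```python
from functools import partial

def multi_apply(func, *args, **kwargs):
    # Bind shared kwargs, then call func on each element-wise group of *args
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    # [(cls1, bbox1), (cls2, bbox2)] -> ([cls1, cls2], [bbox1, bbox2])
    return tuple(map(list, zip(*map_results)))
```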
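Here is a small runnable demonstration of the regrouping that multi_apply performs. It repeats the helper so the snippet stands alone. toy_forward_single is a hypothetical stand-in for forward_single that returns a (cls, bbox) pair per level. The second call shows that several lists are consumed element-wise, which is how the per-level loss is computed.

```python
from functools import partial

def multi_apply(func, *args, **kwargs):
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))

def toy_forward_single(feat, scale=1):
    # Hypothetical per-level head: returns a (cls, bbox) pair
    return f"cls({feat}x{scale})", f"bbox({feat}x{scale})"

feats = ["P3", "P4", "P5"]
cls_scores, bbox_preds = multi_apply(toy_forward_single, feats, scale=2)
print(cls_scores)  # ['cls(P3x2)', 'cls(P4x2)', 'cls(P5x2)']
print(bbox_preds)  # ['bbox(P3x2)', 'bbox(P4x2)', 'bbox(P5x2)']

# Two input lists are zipped element-wise: (1, 3) and (2, 4)
sums, products = multi_apply(lambda a, b: (a + b, a * b), [1, 2], [3, 4])
print(sums, products)  # [4, 6] [3, 8]
```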
2. loss: computing the loss involves anchor generation, anchor target generation, and the loss functions.
The code for the first two parts is:
```python
anchor_list, valid_flag_list = self.get_anchors(
    featmap_sizes, img_metas, device=device)
```
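Anchor generation was covered earlier. The anchor_list obtained here holds every anchor at every scale for each image. It is a nested list indexed as [num_imgs][num_levels], and each entry is a tensor of shape (num_anchors, 4) for that level. valid_flag_list has the same nesting and marks which anchors are valid.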
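To make the per-level anchor tensor and its valid flags concrete, here is a simplified NumPy sketch. It is not mmdet's AnchorGenerator, which differs in details such as the anchor-centre offset and rounding. The validity rule is an assumption about what valid_flag_list encodes: an anchor counts as valid when its feature-map cell lies inside a given valid region (valid_h, valid_w), for example the part not covered by batch padding. The helper names base_anchors, grid_anchors and valid_flags are hypothetical.

```python
import numpy as np

def base_anchors(base_size, scales, ratios):
    """A anchors centred at the origin as (x1, y1, x2, y2); ratio = h / w."""
    ratios = np.asarray(ratios, dtype=float)
    scales = np.asarray(scales, dtype=float)
    h_ratios = np.sqrt(ratios)
    w_ratios = 1.0 / h_ratios
    ws = (base_size * w_ratios[:, None] * scales[None, :]).reshape(-1)
    hs = (base_size * h_ratios[:, None] * scales[None, :]).reshape(-1)
    return np.stack([-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs], axis=1)

def grid_anchors(base, feat_h, feat_w, stride):
    """Shift the A base anchors to every cell; result is (feat_h * feat_w * A, 4)."""
    xx, yy = np.meshgrid(np.arange(feat_w) * stride, np.arange(feat_h) * stride)
    shifts = np.stack([xx, yy, xx, yy], axis=-1).reshape(-1, 1, 4)  # (H*W, 1, 4)
    # Ordering of rows: by row y, then column x, then anchor index a
    return (shifts + base[None, :, :]).reshape(-1, 4)

def valid_flags(feat_h, feat_w, valid_h, valid_w, num_base):
    """Assumed rule: True for anchors whose cell lies in the valid region."""
    cell_valid = (np.arange(feat_h)[:, None] < valid_h) & (np.arange(feat_w)[None, :] < valid_w)
    return np.repeat(cell_valid.reshape(-1), num_base)  # same (y, x, a) order

base = base_anchors(base_size=8, scales=[4], ratios=[0.5, 1.0, 2.0])  # A = 3
anchors = grid_anchors(base, feat_h=4, feat_w=5, stride=8)
flags = valid_flags(4, 5, valid_h=3, valid_w=5, num_base=len(base))
print(anchors.shape)     # (60, 4)
print(int(flags.sum()))  # 45
```

For a batch, the same per-level arrays would be repeated per image to form the [num_imgs][num_levels] nesting, while the flags can differ per image.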
```python
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(
    anchor_list,
    valid_flag_list,
    gt_bboxes,
    img_metas,
    self.target_means,
    self.target_stds,
    cfg,
    gt_bboxes_ignore_list=gt_bboxes_ignore,
    gt_labels_list=gt_labels,
    label_channels=label_channels,
    sampling=self.sampling)
if cls_reg_targets is None:
    return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
 num_total_pos, num_total_neg) = cls_reg_targets
# With a sampler, normalize by all sampled anchors (pos + neg);
# without sampling (e.g. focal loss), normalize by positives only.
num_total_samples = (
    num_total_pos + num_total_neg if self.sampling else num_total_pos)
```
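anchor_target receives target_means and target_stds because the regression targets are normalized box deltas, not raw coordinates. The sketch below uses the standard (dx, dy, dw, dh) encoding and then subtracts the means and divides by the stds. It is simplified: it omits the "+1" width convention found in some older mmdet code. The function name bbox2delta and the stds values in the example are illustrative assumptions.

```python
import numpy as np

def bbox2delta(anchors, gts, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
    """Encode (x1, y1, x2, y2) gt boxes relative to anchors as normalized deltas."""
    px = (anchors[:, 0] + anchors[:, 2]) * 0.5
    py = (anchors[:, 1] + anchors[:, 3]) * 0.5
    pw = anchors[:, 2] - anchors[:, 0]
    ph = anchors[:, 3] - anchors[:, 1]
    gx = (gts[:, 0] + gts[:, 2]) * 0.5
    gy = (gts[:, 1] + gts[:, 3]) * 0.5
    gw = gts[:, 2] - gts[:, 0]
    gh = gts[:, 3] - gts[:, 1]
    deltas = np.stack([(gx - px) / pw,      # centre offset in anchor widths
                       (gy - py) / ph,      # centre offset in anchor heights
                       np.log(gw / pw),     # log width ratio
                       np.log(gh / ph)],    # log height ratio
                      axis=1)
    # Normalize with the configured means / stds
    return (deltas - np.asarray(means)) / np.asarray(stds)

anchors = np.array([[0.0, 0.0, 10.0, 10.0]])
gts = np.array([[5.0, 0.0, 15.0, 20.0]])
deltas = bbox2delta(anchors, gts, stds=(0.1, 0.1, 0.2, 0.2))
print(deltas)  # raw (0.5, 0.5, 0, log 2) scaled to about [5, 5, 0, 3.4657]
```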
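The key step here is generating the anchor targets. This involves two operations, the assigner and the sampler, which deserve a separate write-up. cls_reg_targets contains each anchor's labels, label weights, bbox targets, bbox weights and so on. These tensors are regrouped level-first, i.e. [num_levels][batch_size, num_anchors, ...].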
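The level-first regrouping can be sketched as follows. Per-image targets cover all levels concatenated, so stacking them over images and slicing by each level's anchor count gives one (batch_size, num_anchors_level, ...) array per level. mmdet has a similarly named helper, images_to_levels. This NumPy version is an illustrative sketch, not a copy of it.

```python
import numpy as np

def images_to_levels(per_image_targets, num_level_anchors):
    """[img0_target, img1_target, ...] -> [level0_target, level1_target, ...]."""
    stacked = np.stack(per_image_targets, axis=0)  # (num_imgs, total_anchors, ...)
    level_targets = []
    start = 0
    for n in num_level_anchors:
        level_targets.append(stacked[:, start:start + n])  # (num_imgs, n, ...)
        start += n
    return level_targets

num_level_anchors = [6, 3]  # hypothetical: 6 anchors on level 0, 3 on level 1
labels_img0 = np.arange(9)
labels_img1 = np.arange(9) + 100
labels_levels = images_to_levels([labels_img0, labels_img1], num_level_anchors)
print([lvl.shape for lvl in labels_levels])  # [(2, 6), (2, 3)]

bbox_targets = [np.zeros((9, 4)), np.ones((9, 4))]
bbox_levels = images_to_levels(bbox_targets, num_level_anchors)
print([lvl.shape for lvl in bbox_levels])  # [(2, 6, 4), (2, 3, 4)]
```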
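Once the targets are available, the loss can be computed. This is also done per level with multi_apply, and the per-level results are then aggregated.
Typically, cross-entropy loss is used for classification and smooth L1 loss for regression.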
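A plain NumPy sketch of the two losses shows how label_weights, bbox_weights and avg_factor come into play. It assumes the weighted-sum-divided-by-avg_factor reduction used when avg_factor is passed. The classification part shows the softmax variant; with use_sigmoid_cls a binary cross-entropy would be used instead. The function names and the beta default here are illustrative, not mmdet's API.

```python
import numpy as np

def softmax_cross_entropy(logits, labels, weights, avg_factor):
    """Per-sample CE, weighted by label_weights, summed, divided by avg_factor."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_sample = -log_probs[np.arange(len(labels)), labels]
    return (per_sample * weights).sum() / avg_factor

def smooth_l1(pred, target, weights, avg_factor, beta=1.0):
    """Quadratic below beta, linear above; weights zero out non-positive anchors."""
    diff = np.abs(pred - target)
    per_elem = np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return (per_elem * weights).sum() / avg_factor

num_total_samples = 2  # hypothetical avg_factor shared by both losses
loss_cls = softmax_cross_entropy(np.zeros((2, 2)), np.array([0, 1]),
                                 np.array([1.0, 1.0]), num_total_samples)
print(loss_cls)  # ≈ 0.6931 (= log 2)

pred = np.zeros((2, 4))
target = np.array([[0.5, 0.0, 0.0, 2.0], [0.0, 0.0, 0.0, 0.0]])
bbox_weights = np.array([[1.0] * 4, [0.0] * 4])  # second anchor is a negative
loss_bbox = smooth_l1(pred, target, bbox_weights, num_total_samples)
print(loss_bbox)  # (0.125 + 1.5) / 2 = 0.8125
```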
```python
def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                bbox_targets, bbox_weights, num_total_samples, cfg):
    # classification loss: flatten targets to one entry per anchor
    labels = labels.reshape(-1)
    label_weights = label_weights.reshape(-1)
    # (N, A*C, H, W) -> (N, H, W, A*C) -> (N*H*W*A, C), rows follow anchor order
    cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
    loss_cls = self.loss_cls(
        cls_score, labels, label_weights, avg_factor=num_total_samples)
    # regression loss: same flattening, 4 deltas per anchor
    bbox_targets = bbox_targets.reshape(-1, 4)
    bbox_weights = bbox_weights.reshape(-1, 4)
    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
    loss_bbox = self.loss_bbox(
        bbox_pred,
        bbox_targets,
        bbox_weights,
        avg_factor=num_total_samples)
    return loss_cls, loss_bbox
```
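The permute before reshape in loss_single matters. Moving channels last and then splitting them into rows of C orders the rows by (image, y, x, anchor), so each row lines up with one anchor's flattened label. The NumPy check below fills each channel with a code recording (anchor, class, y, x) to make the layout visible. np.transpose plays the role of torch's permute here, and the channel layout a * C + c is the interpretation the reshape implies.

```python
import numpy as np

N, A, C, H, W = 1, 2, 3, 2, 2  # tiny hypothetical sizes
cls_score = np.zeros((N, A * C, H, W), dtype=int)
for a in range(A):
    for c in range(C):
        for y in range(H):
            for x in range(W):
                # code = 1000*anchor + 100*class + 10*y + x
                cls_score[0, a * C + c, y, x] = 1000 * a + 100 * c + 10 * y + x

flat = cls_score.transpose(0, 2, 3, 1).reshape(-1, C)
print(flat.shape)  # (8, 3)
print(flat[0])     # [  0 100 200]  -> y=0, x=0, anchor 0
print(flat[1])     # [1000 1100 1200]  -> y=0, x=0, anchor 1
print(flat[2])     # [  1 101 201]  -> y=0, x=1, anchor 0
```

Reshaping (N, A*C, H, W) directly to (-1, C) without the permute would instead mix values from different spatial locations into one row.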
3. get_bboxes
get_bboxes converts the model outputs into predicted boxes and labels.
Its docstring describes the inputs and outputs as follows:
```
Transform network output for a batch into labeled boxes.

Args:
    cls_scores (list[Tensor]): Box scores for each scale level
        Has shape (N, num_anchors * num_classes, H, W)
    bbox_preds (list[Tensor]): Box energies / deltas for each scale
        level with shape (N, num_anchors * 4, H, W)
    img_metas (list[dict]): size / scale info for each image
    cfg (mmcv.Config): test / postprocessing configuration
    rescale (bool): if True, return boxes in original image space

Returns:
    list[tuple[Tensor, Tensor]]: each item in result_list is 2-tuple.
        The first item is an (n, 5) tensor, where the first 4 columns
        are bounding box positions (tl_x, tl_y, br_x, br_y) and the
        5-th column is a score between 0 and 1. The second item is a
        (n,) tensor where each item is the class index of the
        corresponding box.
```
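The central step inside get_bboxes is decoding the predicted deltas back into (tl_x, tl_y, br_x, br_y) boxes around the anchors. Decoding inverts the target normalization: multiply by the stds, add the means, and apply the offsets and log-scale ratios. The result is optionally clipped to the image. Below is a simplified NumPy sketch of that decoding. The name delta2bbox and the max_shape argument are assumptions for illustration, and the sketch omits details such as clamping extreme dw/dh values. Score thresholding and per-class NMS, which follow decoding, are not shown.

```python
import numpy as np

def delta2bbox(anchors, deltas, means=(0.0, 0.0, 0.0, 0.0),
               stds=(1.0, 1.0, 1.0, 1.0), max_shape=None):
    """Decode normalized (dx, dy, dw, dh) into (x1, y1, x2, y2) boxes."""
    d = deltas * np.asarray(stds) + np.asarray(means)  # undo normalization
    px = (anchors[:, 0] + anchors[:, 2]) * 0.5
    py = (anchors[:, 1] + anchors[:, 3]) * 0.5
    pw = anchors[:, 2] - anchors[:, 0]
    ph = anchors[:, 3] - anchors[:, 1]
    gx = px + d[:, 0] * pw       # shift the centre
    gy = py + d[:, 1] * ph
    gw = pw * np.exp(d[:, 2])    # rescale width / height
    gh = ph * np.exp(d[:, 3])
    boxes = np.stack([gx - 0.5 * gw, gy - 0.5 * gh,
                      gx + 0.5 * gw, gy + 0.5 * gh], axis=1)
    if max_shape is not None:    # clip to the image, max_shape = (h, w)
        h, w = max_shape[:2]
        boxes[:, 0::2] = boxes[:, 0::2].clip(0, w)
        boxes[:, 1::2] = boxes[:, 1::2].clip(0, h)
    return boxes

anchors = np.array([[0.0, 0.0, 10.0, 10.0]])
deltas = np.array([[5.0, 5.0, 0.0, np.log(2) / 0.2]])
stds = (0.1, 0.1, 0.2, 0.2)  # hypothetical target_stds
print(delta2bbox(anchors, deltas, stds=stds))                      # ≈ [[5, 0, 15, 20]]
print(delta2bbox(anchors, deltas, stds=stds, max_shape=(12, 12)))  # ≈ [[5, 0, 12, 12]]
```

When rescale is True, returning boxes "in original image space" amounts to dividing the decoded coordinates by the resize scale factor recorded in img_metas.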