Next, let's look at the loss function defined in this head.
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
def loss(self,
cls_scores,
bbox_preds,
centernesses,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
'''
cls_scores: [5][batchsize,80,H_i,W_i]
bbox_preds: [5][batchsize,4,H_i,W_i]
centernesses: [5][batchsize,1,H_i,W_i]
gt_bboxes: [batchsize][num_obj,4]
gt_labels: [batchsize][num_obj]
img_metas: [batchsize][(dict)dict_keys(['filename', 'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip', 'img_norm_cfg'])]
cfg: {'assigner': {'type': 'MaxIoUAssigner', 'pos_iou_thr': 0.5, 'neg_iou_thr': 0.4, 'min_pos_iou': 0, 'ignore_iof_thr': -1}, 'allowed_border': -1, 'pos_weight': -1, 'debug': False}
'''
assert len(cls_scores) == len(bbox_preds) == len(centernesses) # 5
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] # sizes of the P3-P7 feature maps
'''
[torch.Size([100, 152]),
torch.Size([50, 76]),
torch.Size([25, 38]),
torch.Size([13, 19]),
torch.Size([7, 10])]
'''
# The feature map size determines how finely the original image is divided into a grid: each feature-map pixel maps back to the center of one grid cell, so differently sized feature maps define grids of different granularity
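# For example, with stride 8 the feature-map pixel at (row i, col j) maps back to
# the image point (8*j + 4, 8*i + 4), i.e. the grid-cell center; the stride // 2
# offset is applied in get_points_single below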
# bbox_preds[0].dtype:torch.float32
# all_level_points:(list) [5][n_points][2]
all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
gt_labels)
'''
labels:[5][batch_size*level_points_i]
bbox_targets:[5][batch_size*level_points_i,4]
'''
num_imgs = cls_scores[0].size(0)
# flatten cls_scores, bbox_preds and centerness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
for bbox_pred in bbox_preds
]
flatten_centerness = [
centerness.permute(0, 2, 3, 1).reshape(-1)
for centerness in centernesses
]
flatten_cls_scores = torch.cat(flatten_cls_scores) # torch.Size([89600, 80]): the 5-level outputs for every point of every image
flatten_bbox_preds = torch.cat(flatten_bbox_preds) # torch.Size([89600, 4])
flatten_centerness = torch.cat(flatten_centerness) # torch.Size([89600])
flatten_labels = torch.cat(labels) # torch.Size([89600])
flatten_bbox_targets = torch.cat(bbox_targets) # torch.Size([89600, 4])
# repeat points to align with bbox_preds
flatten_points = torch.cat(
[points.repeat(num_imgs, 1) for points in all_level_points]) # torch.Size([89600, 2])
pos_inds = flatten_labels.nonzero().reshape(-1) # label 0 is background here, so nonzero() picks out the positive points
num_pos = len(pos_inds)
loss_cls = self.loss_cls(
flatten_cls_scores, flatten_labels,
avg_factor=num_pos + num_imgs) # add num_imgs so the denominator is non-zero even when num_pos is 0
pos_bbox_preds = flatten_bbox_preds[pos_inds]
pos_centerness = flatten_centerness[pos_inds]
if num_pos > 0:
pos_bbox_targets = flatten_bbox_targets[pos_inds]
pos_centerness_targets = self.centerness_target(pos_bbox_targets)
pos_points = flatten_points[pos_inds]
# the network regresses distances (l, t, r, b); decode them back into corner coordinates
pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds) # mmdet/core/bbox/transforms
pos_decoded_target_preds = distance2bbox(pos_points,
pos_bbox_targets)
# centerness weighted iou loss
loss_bbox = self.loss_bbox(
pos_decoded_bbox_preds,
pos_decoded_target_preds,
weight=pos_centerness_targets,
avg_factor=pos_centerness_targets.sum())
loss_centerness = self.loss_centerness(pos_centerness,
pos_centerness_targets)
else:
loss_bbox = pos_bbox_preds.sum()
loss_centerness = pos_centerness.sum()
return dict(
loss_cls=loss_cls,
loss_bbox=loss_bbox,
loss_centerness=loss_centerness)
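For reference, distance2bbox simply adds or subtracts the four predicted distances around each point to recover (x1, y1, x2, y2); its core logic (from mmdet/core/bbox/transforms.py, sketched here without the docstring) is:

def distance2bbox(points, distance, max_shape=None):
    # points: [N, 2] (x, y); distance: [N, 4] (left, top, right, bottom)
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:  # optionally clamp the corners to the image
        x1 = x1.clamp(min=0, max=max_shape[1] - 1)
        y1 = y1.clamp(min=0, max=max_shape[0] - 1)
        x2 = x2.clamp(min=0, max=max_shape[1] - 1)
        y2 = y2.clamp(min=0, max=max_shape[0] - 1)
    return torch.stack([x1, y1, x2, y2], -1)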
Let's look at the details of get_points.
def get_points(self, featmap_sizes, dtype, device):
"""Get points according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
dtype (torch.dtype): Type of points.
device (torch.device): Device of points.
Returns:
list[Tensor]: points of each feature level.
"""
mlvl_points = []
for i in range(len(featmap_sizes)):
mlvl_points.append(
self.get_points_single(featmap_sizes[i], self.strides[i],
dtype, device))
return mlvl_points
def get_points_single(self, featmap_size, stride, dtype, device):
h, w = featmap_size # e.g. 100, 152
x_range = torch.arange(
0, w * stride, stride, dtype=dtype, device=device)
'''
tensor([ 0., 8., 16., 24., 32., 40., 48., 56., 64., 72.,
80., 88., 96., 104., 112., 120., 128., 136., 144., 152.,
160., 168., 176., 184., 192., 200., 208., 216., 224., 232.,
240., 248., 256., 264., 272., 280., 288., 296., 304., 312.,
320., 328., 336., 344., 352., 360., 368., 376., 384., 392.,
400., 408., 416., 424., 432., 440., 448., 456., 464., 472.,
480., 488., 496., 504., 512., 520., 528., 536., 544., 552.,
560., 568., 576., 584., 592., 600., 608., 616., 624., 632.,
640., 648., 656., 664., 672., 680., 688., 696., 704., 712.,
720., 728., 736., 744., 752., 760., 768., 776., 784., 792.,
800., 808., 816., 824., 832., 840., 848., 856., 864., 872.,
880., 888., 896., 904., 912., 920., 928., 936., 944., 952.,
960., 968., 976., 984., 992., 1000., 1008., 1016., 1024., 1032.,
1040., 1048., 1056., 1064., 1072., 1080., 1088., 1096., 1104., 1112.,
1120., 1128., 1136., 1144., 1152., 1160., 1168., 1176., 1184., 1192.,
1200., 1208.], device='cuda:0')
'''
y_range = torch.arange(
0, h * stride, stride, dtype=dtype, device=device)
'''
tensor([ 0., 8., 16., 24., 32., 40., 48., 56., 64., 72., 80., 88.,
96., 104., 112., 120., 128., 136., 144., 152., 160., 168., 176., 184.,
192., 200., 208., 216., 224., 232., 240., 248., 256., 264., 272., 280.,
288., 296., 304., 312., 320., 328., 336., 344., 352., 360., 368., 376.,
384., 392., 400., 408., 416., 424., 432., 440., 448., 456., 464., 472.,
480., 488., 496., 504., 512., 520., 528., 536., 544., 552., 560., 568.,
576., 584., 592., 600., 608., 616., 624., 632., 640., 648., 656., 664.,
672., 680., 688., 696., 704., 712., 720., 728., 736., 744., 752., 760.,
768., 776., 784., 792.], device='cuda:0')
'''
y, x = torch.meshgrid(y_range, x_range)
'''
y
tensor([[ 0., 0., 0., ..., 0., 0., 0.],
[ 8., 8., 8., ..., 8., 8., 8.],
[ 16., 16., 16., ..., 16., 16., 16.],
...,
[776., 776., 776., ..., 776., 776., 776.],
[784., 784., 784., ..., 784., 784., 784.],
[792., 792., 792., ..., 792., 792., 792.]], device='cuda:0')
x
tensor([[ 0., 8., 16., ..., 1192., 1200., 1208.],
[ 0., 8., 16., ..., 1192., 1200., 1208.],
[ 0., 8., 16., ..., 1192., 1200., 1208.],
...,
[ 0., 8., 16., ..., 1192., 1200., 1208.],
[ 0., 8., 16., ..., 1192., 1200., 1208.],
[ 0., 8., 16., ..., 1192., 1200., 1208.]], device='cuda:0')
'''
points = torch.stack(
(x.reshape(-1), y.reshape(-1)), dim=-1) + stride // 2
'''
tensor([[ 4., 4.],
[ 12., 4.],
[ 20., 4.],
...,
[1196., 796.],
[1204., 796.],
[1212., 796.]], device='cuda:0')
'''
return points
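A quick toy check of the ordering (hypothetical numbers, not from the source): for a 2x3 feature map with stride 8 the points come out row-major with x varying fastest, which is exactly why the permute(0, 2, 3, 1).reshape(...) calls in loss line up with these points:

import torch

h, w, stride = 2, 3, 8
x_range = torch.arange(0, w * stride, stride, dtype=torch.float32)
y_range = torch.arange(0, h * stride, stride, dtype=torch.float32)
y, x = torch.meshgrid(y_range, x_range)
points = torch.stack((x.reshape(-1), y.reshape(-1)), dim=-1) + stride // 2
# tensor([[ 4.,  4.],
#         [12.,  4.],
#         [20.,  4.],
#         [ 4., 12.],
#         [12., 12.],
#         [20., 12.]])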
fcos_target generates targets for the feature points of every level (i.e., the centers of the grid cells on the original image). A point is a positive sample when it lies inside a gt box and its regression distances satisfy the size limit of its FPN level. Note that the latest version of the paper introduces center sampling, under which not every point inside a gt box counts as positive.
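Before diving into the code, here is a rough sketch of what that center-sampling rule adds (illustrative only; center_sampling_mask and the radius value are my assumptions based on the paper, not code from this repository): a point only qualifies as positive if it also falls inside a shrunken region around the gt center whose half-width is radius times the stride of the point's FPN level.

import torch

def center_sampling_mask(points, gt_bboxes, strides, radius=1.5):
    # points: [num_points, 2]; gt_bboxes: [num_gts, 4] (xmin, ymin, xmax, ymax)
    # strides: [num_points], the FPN stride each point comes from
    cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2  # [num_gts]
    cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2
    r = radius * strides[:, None]                 # [num_points, 1]
    xs, ys = points[:, 0:1], points[:, 1:2]       # [num_points, 1]
    inside_x = (xs > cx[None] - r) & (xs < cx[None] + r)
    inside_y = (ys > cy[None] - r) & (ys < cy[None] + r)
    # the full version also clips the sampled region to the gt box itself
    return inside_x & inside_y                    # [num_points, num_gts]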
def fcos_target(self, points, gt_bboxes_list, gt_labels_list):
'''
points:(list) [5][n_points][2])
gt_bboxes_list: [batch_size][num_objects,4]
gt_labels_list: [batch_size][num_objects]
'''
assert len(points) == len(self.regress_ranges)
num_levels = len(points) # 5
# expand regress ranges to align with points
expanded_regress_ranges = [
points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
points[i]) for i in range(num_levels)
] # (list)[5][n_points][2])
# concat all levels points and regress ranges
concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
concat_points = torch.cat(points, dim=0)
# get labels and bbox_targets of each image
labels_list, bbox_targets_list = multi_apply(
self.fcos_target_single,
gt_bboxes_list,
gt_labels_list,
points=concat_points,
regress_ranges=concat_regress_ranges)
# labels_list:[batch_size][total_points]
# bbox_targets_list: [batch_size][total_points,4]
# split to per img, per level
num_points = [center.size(0) for center in points]
labels_list = [labels.split(num_points, 0) for labels in labels_list] #[batch_size][5][level_points_i]
bbox_targets_list = [
bbox_targets.split(num_points, 0)
for bbox_targets in bbox_targets_list
] #[batch_size][5][level_points_i,4]
# concat per level image
concat_lvl_labels = []
concat_lvl_bbox_targets = []
for i in range(num_levels):
concat_lvl_labels.append(
torch.cat([labels[i] for labels in labels_list]))
concat_lvl_bbox_targets.append(
torch.cat(
[bbox_targets[i] for bbox_targets in bbox_targets_list]))
# concat_lvl_labels:[5][batch_size*level_points_i]
# concat_lvl_bbox_targets:[5][batch_size*level_points_i,4]
return concat_lvl_labels, concat_lvl_bbox_targets
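fcos_target relies on multi_apply (a small helper in mmdet.core) to run fcos_target_single once per image and then regroup the per-image outputs; its implementation is essentially the following:

from functools import partial

def multi_apply(func, *args, **kwargs):
    # bind the shared keyword arguments, then apply func to each image's args
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    # transpose: a list of per-image (labels, bbox_targets) pairs becomes
    # (list of labels, list of bbox_targets)
    return tuple(map(list, zip(*map_results)))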
def fcos_target_single(self, gt_bboxes, gt_labels, points, regress_ranges):
'''
gt_bboxes: [num_objects,4] xmin,ymin,xmax,ymax
gt_labels: [num_objects]
points: [5_featuremap_total_points,2]
regress_ranges: [5_featuremap_total_points,2] the allowed range for the maximum of the four border distances
'''
num_points = points.size(0)
num_gts = gt_labels.size(0)
if num_gts == 0:
return gt_labels.new_zeros(num_points), \
gt_bboxes.new_zeros((num_points, 4))
areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * (
gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)
# TODO: figure out why these two are different
# areas = areas[None].expand(num_points, num_gts)
areas = areas[None].repeat(num_points, 1) # [num_points, num_gts]
regress_ranges = regress_ranges[:, None, :].expand(
num_points, num_gts, 2) # [num_points, num_gts, 2]
gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4) # [num_points, num_gts, 4]
xs, ys = points[:, 0], points[:, 1]
xs = xs[:, None].expand(num_points, num_gts)
ys = ys[:, None].expand(num_points, num_gts)
left = xs - gt_bboxes[..., 0]
right = gt_bboxes[..., 2] - xs
top = ys - gt_bboxes[..., 1]
bottom = gt_bboxes[..., 3] - ys
bbox_targets = torch.stack((left, top, right, bottom), -1) # [num_points, num_gts, 4]
# condition1: inside a gt bbox
inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0 # if the smallest of the four border distances is <= 0, the location lies outside the gt box
# condition2: limit the regression range for each location
max_regress_distance = bbox_targets.max(-1)[0]
inside_regress_range = (
max_regress_distance >= regress_ranges[..., 0]) & (
max_regress_distance <= regress_ranges[..., 1])
# if more than one gt still matches a location,
# choose the one with the smallest area
areas[inside_gt_bbox_mask == 0] = INF
areas[inside_regress_range == 0] = INF
min_area, min_area_inds = areas.min(dim=1) # both [num_points]
labels = gt_labels[min_area_inds] #[num_points]
labels[min_area == INF] = 0 # 0 is the background label in this version of mmdet
bbox_targets = bbox_targets[range(num_points), min_area_inds]
return labels, bbox_targets
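To make the assignment concrete, a tiny hand-made example (toy numbers, not from the source): a point at (100, 100) against a gt box (50, 50, 200, 150) gives distances l=50, t=50, r=100, b=50, so the point is inside the box, and with a maximum distance of 100 it belongs to the level whose regress range contains 100:

import torch

point = torch.tensor([100., 100.])
gt = torch.tensor([50., 50., 200., 150.])  # xmin, ymin, xmax, ymax
l, t = point[0] - gt[0], point[1] - gt[1]
r, b = gt[2] - point[0], gt[3] - point[1]
ltrb = torch.stack([l, t, r, b])  # tensor([ 50.,  50., 100.,  50.])
inside = ltrb.min() > 0           # True: the point lies inside the box
max_dist = ltrb.max()             # tensor(100.)
# with FCOS's default regress_ranges ((-1, 64), (64, 128), (128, 256), ...),
# 64 <= 100 <= 128, so this point is a positive sample on the second level (P4)
assigned_to_p4 = (max_dist >= 64) & (max_dist <= 128)  # tensor(True)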
centerness_target is exactly what the name suggests: the expression below measures how close the current location is to the object center; the farther away the location is, the closer the value gets to 0, and the nearer it is, the closer the value gets to 1. It is computed only for positive samples.
def centerness_target(self, pos_bbox_targets):
# only calculate pos centerness targets, otherwise there may be nan
left_right = pos_bbox_targets[:, [0, 2]]
top_bottom = pos_bbox_targets[:, [1, 3]]
centerness_targets = (
left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
return torch.sqrt(centerness_targets)
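As a quick sanity check (toy numbers): a positive location with targets (l, t, r, b) = (50, 50, 100, 50) gets centerness sqrt((50/100) * (50/50)) = sqrt(0.5) ≈ 0.707, while a location exactly at the box center (l == r and t == b) would get 1.0:

import torch

pos_bbox_targets = torch.tensor([[50., 50., 100., 50.]])  # (l, t, r, b)
lr = pos_bbox_targets[:, [0, 2]]
tb = pos_bbox_targets[:, [1, 3]]
centerness = torch.sqrt(
    (lr.min(dim=-1)[0] / lr.max(dim=-1)[0]) *
    (tb.min(dim=-1)[0] / tb.max(dim=-1)[0]))
# tensor([0.7071])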