This paper proposes LLA (Loss-aware Label Assignment), a label assignment strategy for pedestrian detection. Concretely, it runs in the following steps (a toy sketch of the matching step is given after the code listing):

1. For every (ground truth, anchor point) pair, compute a classification loss and a regression (IoU) loss from the current predictions.
2. Combine them into a joint cost C = L_cls + λ · L_reg, adding a large constant penalty for anchor points that lie outside the ground-truth box.
3. For each ground truth, take the k anchors with the lowest cost as its positive samples.
4. When one anchor is claimed by several ground truths, keep only the lowest-cost assignment; all unclaimed anchors become background.
5. Compute the final classification, box regression, and IoU-prediction losses on top of this assignment.
The paper shares an author group with YOLOX, and YOLOX's label assignment strategy (SimOTA) can be seen as a slight modification of the scheme proposed here.
Reference link

The core of the implementation is the routine below, which computes the LLA assignment and all detection losses for a batch:
```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

# permute_to_N_HWA_K, sigmoid_focal_loss_jit, iou_loss and get_ious are helper
# functions from the surrounding detection codebase (cvpods in the LLA repo).


def get_lla_assignments_and_losses(self, shifts, targets, box_cls, box_delta, box_iou):
    gt_classes = []
    # Flatten each FPN level from (N, A*K, H, W) to (N, H*W*A, K), then
    # concatenate the levels so every image has one flat list of predictions.
    box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
    box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]
    box_cls = torch.cat(box_cls, dim=1)
    box_delta = torch.cat(box_delta, dim=1)
    box_iou = torch.cat(box_iou, dim=1)

    losses_cls = []
    losses_box_reg = []
    losses_iou = []
    num_fg = 0

    for shifts_per_image, targets_per_image, box_cls_per_image, \
            box_delta_per_image, box_iou_per_image in zip(
                shifts, targets, box_cls, box_delta, box_iou):
        shifts_over_all = torch.cat(shifts_per_image, dim=0)
        gt_boxes = targets_per_image.gt_boxes
        gt_classes = targets_per_image.gt_classes
        deltas = self.shift2box_transform.get_deltas(
            shifts_over_all, gt_boxes.tensor.unsqueeze(1))
        is_in_boxes = deltas.min(dim=-1).values > 0.01
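        # deltas[i, j] holds the (l, t, r, b) distances from anchor point j to
        # the four sides of gt box i (shape (num_gt, num_anchors, 4)); all four
        # positive means the point lies inside that box.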

        shape = (len(targets_per_image), len(shifts_over_all), -1)
        # Keep the unexpanded (num_anchors, ...) views for the final losses.
        box_cls_per_image_unexpanded = box_cls_per_image
        box_delta_per_image_unexpanded = box_delta_per_image

        box_cls_per_image = box_cls_per_image.unsqueeze(0).expand(shape)
        # torch.max(..., 0) clamps the CrowdHuman ignore label (-1) to 0 so that
        # one_hot does not fail; ignored gts are filtered out further below.
        gt_cls_per_image = F.one_hot(
            torch.max(gt_classes, torch.zeros_like(gt_classes)), self.num_classes
        ).float().unsqueeze(1).expand(shape)
        with torch.no_grad():
            loss_cls = sigmoid_focal_loss_jit(
                box_cls_per_image,
                gt_cls_per_image,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)
            loss_cls_bg = sigmoid_focal_loss_jit(
                box_cls_per_image_unexpanded,
                torch.zeros_like(box_cls_per_image_unexpanded),
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)
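            # loss_cls is the pairwise classification cost, shape
            # (num_gt, num_anchors); loss_cls_bg is each anchor's cost of being
            # assigned to background (focal loss against an all-zero target).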
            box_delta_per_image = box_delta_per_image.unsqueeze(0).expand(shape)
            gt_delta_per_image = self.shift2box_transform.get_deltas(
                shifts_over_all, gt_boxes.tensor.unsqueeze(1))
            loss_delta = iou_loss(
                box_delta_per_image,
                gt_delta_per_image,
                box_mode="ltrb",
                loss_type='iou')
            ious = get_ious(
                box_delta_per_image,
                gt_delta_per_image,
                box_mode="ltrb",
                loss_type='iou')
            loss = loss_cls + self.reg_cost * loss_delta + 1e3 * (1 - is_in_boxes.float())
            loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)
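            # LLA cost: C_ij = L_cls(i, j) + reg_cost * L_reg(i, j), plus a huge
            # constant (1e3) for anchors outside gt i so they are never selected.
            # Concatenating the background costs as an extra row gives `loss`
            # the shape (num_gt + 1, num_anchors).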

            num_gt = loss.shape[0] - 1
            num_anchor = loss.shape[1]

            # Topk
            matching_matrix = torch.zeros_like(loss)
            _, topk_idx = torch.topk(loss[:-1], k=self.topk, dim=1, largest=False)
            matching_matrix[torch.arange(num_gt).unsqueeze(1).repeat(
                1, self.topk).view(-1), topk_idx.view(-1)] = 1.

            # make sure one anchor with one gt
            anchor_matched_gt = matching_matrix.sum(0)
            if (anchor_matched_gt > 1).sum() > 0:
                loss_min, loss_argmin = torch.min(loss[:-1, anchor_matched_gt > 1], dim=0)
                matching_matrix[:, anchor_matched_gt > 1] *= 0.
                matching_matrix[loss_argmin, anchor_matched_gt > 1] = 1.
                anchor_matched_gt = matching_matrix.sum(0)
            num_fg += matching_matrix.sum()
            matching_matrix[-1] = 1. - anchor_matched_gt  # assignment for Background
            assigned_gt_inds = torch.argmax(matching_matrix, dim=0)
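            # At this point every gt holds its topk lowest-cost anchors, each
            # anchor belongs to at most one gt, and all unmatched anchors sit
            # in the background row, so assigned_gt_inds takes values in
            # [0, num_gt] with num_gt meaning "background".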

            gt_cls_per_image_bg = gt_cls_per_image.new_zeros(
                (gt_cls_per_image.size(1), gt_cls_per_image.size(2))).unsqueeze(0)
            gt_cls_per_image_with_bg = torch.cat(
                [gt_cls_per_image, gt_cls_per_image_bg], dim=0)
            cls_target_per_image = gt_cls_per_image_with_bg[
                assigned_gt_inds, torch.arange(num_anchor)]

            # Dealing with Crowdhuman ignore label
            gt_classes_ = torch.cat([gt_classes, gt_classes.new_zeros(1)])
            anchor_cls_labels = gt_classes_[assigned_gt_inds]
            valid_flag = anchor_cls_labels >= 0
            pos_mask = assigned_gt_inds != len(targets_per_image)  # get foreground mask
            valid_fg = pos_mask & valid_flag
            assigned_fg_inds = assigned_gt_inds[valid_fg]
            range_fg = torch.arange(num_anchor)[valid_fg]
            ious_fg = ious[assigned_fg_inds, range_fg]
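            # CrowdHuman ignore regions carry the class label -1: anchors
            # matched to them get valid_flag == False and are dropped from
            # every loss term, while the appended zero keeps background anchors
            # valid. ious_fg holds the IoU between each foreground anchor's
            # predicted box and its assigned gt, the target of the IoU branch.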

        # Final losses, computed with gradients on the frozen assignment.
        anchor_loss_cls = sigmoid_focal_loss_jit(
            box_cls_per_image_unexpanded[valid_flag],
            cls_target_per_image[valid_flag],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma).sum(dim=-1)
        delta_target = gt_delta_per_image[assigned_fg_inds, range_fg]
        anchor_loss_delta = 2. * iou_loss(
            box_delta_per_image_unexpanded[valid_fg],
            delta_target,
            box_mode="ltrb",
            loss_type=self.iou_loss_type)
        anchor_loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
            box_iou_per_image.squeeze(1)[valid_fg],
            ious_fg,
            reduction='none')
        losses_cls.append(anchor_loss_cls.sum())
        losses_box_reg.append(anchor_loss_delta.sum())
        losses_iou.append(anchor_loss_iou.sum())

    if self.norm_sync:
        # Average the foreground count across GPUs so every worker normalizes
        # by the same number.
        dist.all_reduce(num_fg)
        num_fg = num_fg.float() / dist.get_world_size()
    return {
        'loss_cls': torch.stack(losses_cls).sum() / num_fg,
        'loss_box_reg': torch.stack(losses_box_reg).sum() / num_fg,
        'loss_iou': torch.stack(losses_iou).sum() / num_fg
    }
```
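
The top-k matching in the middle of the routine is easiest to see in isolation. Below is a minimal, self-contained sketch of the cost-based assignment (steps 3 and 4 above) on a hand-made toy cost matrix; the sizes and values are invented for illustration and this is not the authors' code:

```python
import torch

# Toy setup: 2 ground truths, 6 anchors, keep k = 2 anchors per gt.
# cost[i, j] plays the role of C_ij = L_cls + reg_cost * L_reg (+ 1e3 when
# anchor j lies outside gt i); the last row is the background cost.
topk = 2
cost = torch.tensor([
    [0.2, 0.5, 8.0, 0.3, 9.0, 7.0],   # gt 0
    [6.0, 0.4, 0.6, 0.1, 5.0, 8.0],   # gt 1
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],   # background row
])
num_gt = cost.shape[0] - 1

# Each gt claims its k lowest-cost anchors.
matching = torch.zeros_like(cost)
_, idx = torch.topk(cost[:-1], k=topk, dim=1, largest=False)
matching[torch.arange(num_gt).unsqueeze(1).repeat(1, topk).view(-1),
         idx.view(-1)] = 1.

# An anchor claimed by several gts keeps only the cheapest assignment
# (here anchor 3 is wanted by both gts and goes to gt 1).
claimed = matching.sum(0)
if (claimed > 1).any():
    _, argmin = torch.min(cost[:-1, claimed > 1], dim=0)
    matching[:, claimed > 1] = 0.
    matching[argmin, claimed > 1] = 1.

# Whatever is left unmatched falls back to the background row.
matching[-1] = 1. - matching[:-1].sum(0)
assigned = torch.argmax(matching, dim=0)
print(assigned)  # tensor([0, 1, 2, 1, 2, 2]); the value 2 (= num_gt) means background
```

The same mechanics appear in `get_lla_assignments_and_losses`, except that the real cost matrix is built from the network's focal and IoU losses rather than hand-written numbers.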