class TaskAlignedAssigner(nn.Module):
    """Task-aligned label assigner from TOOD (Task-aligned One-stage Object Detection).

    Args:
        topk (int): number of candidate anchors kept per ground-truth box.
        alpha (float): exponent applied to the classification score in the
            alignment metric.
        beta (float): exponent applied to the IoU in the alignment metric.
        eps (float): small constant added to a denominator to avoid division
            by zero when rescaling the alignment metric.
        num_classes (int): stored on the instance; ``forward`` reads the class
            count from ``pred_scores.shape`` instead.
    """

    def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9, num_classes=80):
        super().__init__()
        self.topk = topk
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
        self.num_classes = num_classes

    @torch.no_grad()
    def forward(self,
                pred_scores,
                pred_bboxes,
                anchor_points,
                num_anchors_list,
                gt_labels,
                gt_bboxes,
                pad_gt_mask,
                bg_index,
                gt_scores=None):
        r"""Assign a ground-truth box (or background) to every anchor.

        This code is based on
        https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/task_aligned_assigner.py

        The assignment is done in the following steps:

        1. compute the alignment metric between all bboxes (of all pyramid
           levels) and the gts
        2. select the top-k bboxes as candidates for each gt
        3. limit the positive samples' centers to lie inside the gt (because
           an anchor-free detector can only predict positive distances)
        4. if an anchor box is assigned to multiple gts, the one with the
           highest IoU is selected

        Args:
            pred_scores (Tensor, float32): predicted class probabilities,
                shape (B, L, C).
            pred_bboxes (Tensor, float32): predicted boxes, shape (B, L, 4).
            anchor_points (Tensor, float32): predefined anchor centers,
                shape (L, 2), "cxcy" format.
            num_anchors_list (List): number of anchors per pyramid level.
                Accepted for interface compatibility; not used here.
            gt_labels (Tensor, int64|int32): ground-truth labels,
                shape (B, n, 1).
            gt_bboxes (Tensor, float32): ground-truth boxes, shape (B, n, 4).
            pad_gt_mask (Tensor, float32): 1 means a real box, 0 means
                padding, shape (B, n, 1).
            bg_index (int): class index used to mark background anchors.
                Must lie in ``[0, C]``.
            gt_scores (Tensor|None, float32): score of gt_bboxes,
                shape (B, n, 1). Accepted for interface compatibility; not
                used here.

        Returns:
            assigned_labels (Tensor): (B, L)
            assigned_bboxes (Tensor): (B, L, 4)
            assigned_scores (Tensor): (B, L, C)
        """
        # Input validation.
        assert pred_scores.ndim == pred_bboxes.ndim
        assert gt_labels.ndim == gt_bboxes.ndim and \
            gt_bboxes.ndim == 3

        batch_size, num_anchors, num_classes = pred_scores.shape
        _, num_max_boxes, _ = gt_bboxes.shape
        device = pred_scores.device

        # Negative batch: no ground truth at all, so every anchor is
        # background. Allocate on the inputs' device so the outputs can be
        # combined with the predictions without a device mismatch.
        if num_max_boxes == 0:
            assigned_labels = torch.full(
                [batch_size, num_anchors], bg_index,
                dtype=torch.long, device=device)
            assigned_bboxes = torch.zeros(
                [batch_size, num_anchors, 4],
                dtype=gt_bboxes.dtype, device=device)
            assigned_scores = torch.zeros(
                [batch_size, num_anchors, num_classes],
                dtype=pred_scores.dtype, device=device)
            return assigned_labels, assigned_bboxes, assigned_scores

        # IoU between every gt and every predicted bbox, [B, n, L].
        ious = iou_similarity(gt_bboxes, pred_bboxes)

        # Gather, for every gt, the predicted score of that gt's class at
        # every anchor, [B, n, L].
        pred_scores = pred_scores.permute(0, 2, 1)  # B, C, L
        gt_labels = gt_labels.long()  # B, n, 1
        batch_ind = torch.arange(
            end=batch_size, dtype=gt_labels.dtype,
            device=device).unsqueeze(-1)  # B, 1
        # Advanced indexing broadcasts (B, 1) against (B, n) and keeps the
        # trailing L axis, replacing a per-sample Python loop. The cast keeps
        # the float32 result the loop-based version produced.
        bbox_cls_scores = pred_scores[batch_ind, gt_labels.squeeze(-1)].to(
            torch.float)

        # Alignment metric = cls_score^alpha * iou^beta, [B, n, L].
        alignment_metrics = bbox_cls_scores.pow(self.alpha) * ious.pow(
            self.beta)

        # Which anchor centers fall inside which gt box, [B, n, L].
        is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)

        # Select the top-k anchors by alignment metric as candidates for
        # each gt, [B, n, L].
        is_in_topk = gather_topk_anchors(
            alignment_metrics * is_in_gts,
            self.topk,
            topk_mask=pad_gt_mask.repeat([1, 1, self.topk]).to(torch.bool))

        # Positive-sample mask, [B, n, L]; padded gts are zeroed out.
        mask_positive = is_in_topk * is_in_gts * pad_gt_mask

        # If an anchor is assigned to multiple gts, keep only the gt with
        # the highest IoU, [B, n, L].
        mask_positive_sum = mask_positive.sum(dim=-2)
        if mask_positive_sum.max() > 1:
            mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).repeat(
                [1, num_max_boxes, 1])
            is_max_iou = compute_max_iou_anchor(ious)
            mask_positive = torch.where(mask_multiple_gts, is_max_iou,
                                        mask_positive)
            mask_positive_sum = mask_positive.sum(dim=-2)
        assigned_gt_index = mask_positive.argmax(dim=-2)

        # Build the assigned targets. Offsetting by the batch index turns the
        # per-sample gt index into an index into the flattened (B * n) gts.
        assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
        assigned_labels = gt_labels.flatten()[assigned_gt_index]
        assigned_labels = torch.where(
            mask_positive_sum > 0, assigned_labels,
            torch.full_like(assigned_labels, bg_index))
        assigned_bboxes = gt_bboxes.reshape([-1, 4])[assigned_gt_index]

        # One-hot over C + 1 classes (the extra slot holds background), then
        # drop the background column. Selecting every column except
        # ``bg_index`` is correct wherever the background index sits; slicing
        # ``[:bg_index]`` was only right for ``bg_index == num_classes``.
        assigned_scores = F.one_hot(assigned_labels, num_classes + 1)
        keep_classes = list(range(num_classes + 1))
        keep_classes.remove(bg_index)  # ValueError if bg_index not in [0, C]
        keep_classes = torch.tensor(
            keep_classes, dtype=torch.long, device=assigned_scores.device)
        assigned_scores = assigned_scores.index_select(-1, keep_classes)

        # Rescale the alignment metric so that, per gt, its maximum equals
        # that gt's maximum IoU among its positive anchors.
        alignment_metrics *= mask_positive
        max_metrics_per_instance = alignment_metrics.max(
            dim=-1, keepdim=True)[0]
        max_ious_per_instance = (ious * mask_positive).max(
            dim=-1, keepdim=True)[0]
        alignment_metrics = alignment_metrics / (
            max_metrics_per_instance + self.eps) * max_ious_per_instance
        # Collapse the gt axis: each anchor keeps its largest rescaled metric.
        alignment_metrics = alignment_metrics.max(-2)[0].unsqueeze(-1)
        assigned_scores = assigned_scores * alignment_metrics

        return assigned_labels, assigned_bboxes, assigned_scores