class TaskAlignedAssigner(nn.Module):
    """Task-aligned label assigner from TOOD.

    TOOD: Task-aligned One-stage Object Detection.

    Args:
        topk (int): number of candidate anchors selected per ground-truth box.
        alpha (float): exponent applied to the predicted class score in the
            alignment metric.
        beta (float): exponent applied to the IoU in the alignment metric.
        eps (float): small constant added to the denominator when the
            alignment metrics are rescaled.
        num_classes (int): number of classes; stored on the module.
    """

    def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9, num_classes=80):
        super().__init__()
        self.topk = topk
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
        self.num_classes = num_classes

    # Assignment targets are treated as constants, so no graph is recorded.
    @torch.no_grad()
    def forward(self,
                pred_scores,
                pred_bboxes,
                anchor_points,
                num_anchors_list,
                gt_labels,
                gt_bboxes,
                pad_gt_mask,
                bg_index,
                gt_scores=None):
        r"""Assign a ground-truth box (or background) to every anchor.

        Based on the TOOD task-aligned assigner. The assignment is done in
        the following steps:

        1. compute the alignment metric between all bboxes (bboxes of all
           pyramid levels) and the gts
        2. select the top-k bboxes as candidates for each gt
        3. limit the positive samples' centers to lie inside the gt (because
           the anchor-free detector can only predict positive distances)
        4. if an anchor box is assigned to multiple gts, the one with the
           highest IoU is selected

        Args:
            pred_scores (Tensor, float32): predicted class probabilities,
                shape(B, L, C)
            pred_bboxes (Tensor, float32): predicted boxes, shape(B, L, 4)
            anchor_points (Tensor, float32): predefined anchors, shape(L, 2),
                "cxcy" format
            num_anchors_list (List): number of anchors on each level;
                not used by this method
            gt_labels (Tensor, int64|int32): ground-truth labels,
                shape(B, n, 1)
            gt_bboxes (Tensor, float32): ground-truth boxes, shape(B, n, 4)
            pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox,
                shape(B, n, 1)
            bg_index (int): class index used to mark background
            gt_scores (Tensor|None, float32): score of gt_bboxes,
                shape(B, n, 1); not used by this method

        Returns:
            assigned_labels (Tensor): (B, L)
            assigned_bboxes (Tensor): (B, L, 4)
            assigned_scores (Tensor): (B, L, C)
        """
        assert pred_scores.ndim == pred_bboxes.ndim
        assert gt_labels.ndim == gt_bboxes.ndim and \
            gt_bboxes.ndim == 3

        batch_size, num_anchors, num_classes = pred_scores.shape
        _, num_max_boxes, _ = gt_bboxes.shape
        device = pred_scores.device

        # Negative batch: no ground-truth box at all, so every anchor is
        # background. Build the outputs on the prediction device so they can
        # be combined with the predictions without a device mismatch.
        if num_max_boxes == 0:
            assigned_labels = torch.full(
                [batch_size, num_anchors], bg_index,
                dtype=torch.long, device=device)
            assigned_bboxes = torch.zeros(
                [batch_size, num_anchors, 4], device=device)
            assigned_scores = torch.zeros(
                [batch_size, num_anchors, num_classes], device=device)
            return assigned_labels, assigned_bboxes, assigned_scores

        # IoU between every gt and every predicted bbox, [B, n, L]
        ious = iou_similarity(gt_bboxes, pred_bboxes)

        # For each gt, gather the predicted score of that gt's class at every
        # anchor, [B, n, L]. Advanced indexing with a (B, 1) batch index and
        # a (B, n) class index replaces the per-image Python loop.
        pred_scores = pred_scores.permute(0, 2, 1)  # B, C, L
        gt_labels = gt_labels.long()  # B, n, 1
        batch_ind = torch.arange(
            batch_size, dtype=gt_labels.dtype, device=device).unsqueeze(-1)  # B, 1
        bbox_cls_scores = pred_scores[batch_ind, gt_labels.squeeze(-1)]
        bbox_cls_scores = bbox_cls_scores.to(torch.float)

        # Alignment metric, [B, n, L]
        alignment_metrics = bbox_cls_scores.pow(self.alpha) * ious.pow(
            self.beta)

        # Which anchor centers fall inside which gt box, [B, n, L]
        is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)

        # Select the top-k anchors by alignment metric as candidates for each
        # gt, [B, n, L].
        # NOTE(review): topk is forwarded positionally here; confirm against
        # the signature of gather_topk_anchors.
        is_in_topk = gather_topk_anchors(
            alignment_metrics * is_in_gts,
            self.topk,
            topk_mask=pad_gt_mask.repeat([1, 1, self.topk]).to(torch.bool))

        # Positive-sample mask, [B, n, L]
        mask_positive = is_in_topk * is_in_gts * pad_gt_mask

        # If an anchor is assigned to multiple gts, keep only the gt with the
        # highest IoU, [B, n, L]
        mask_positive_sum = mask_positive.sum(dim=-2)
        if mask_positive_sum.max() > 1:
            mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).repeat(
                [1, num_max_boxes, 1])
            is_max_iou = compute_max_iou_anchor(ious)
            mask_positive = torch.where(mask_multiple_gts, is_max_iou,
                                        mask_positive)
            mask_positive_sum = mask_positive.sum(dim=-2)
        assigned_gt_index = mask_positive.argmax(dim=-2)

        # Assigned targets. Offsetting by batch turns the per-image gt index
        # into an index into the flattened (B * n) gt tensors.
        assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
        assigned_labels = gt_labels.flatten()[assigned_gt_index]
        assigned_labels = torch.where(
            mask_positive_sum > 0, assigned_labels,
            torch.full_like(assigned_labels, bg_index))
        assigned_bboxes = gt_bboxes.reshape([-1, 4])[assigned_gt_index]

        # One-hot over C + 1 classes, then drop the background column
        # wherever it sits (not only when it is the last index).
        assigned_scores = F.one_hot(assigned_labels, num_classes + 1)
        assigned_scores = torch.cat(
            [assigned_scores[..., :bg_index],
             assigned_scores[..., bg_index + 1:]],
            dim=-1)

        # Rescale the alignment metrics: per gt, normalise by the largest
        # metric and scale by the largest IoU among its positive anchors.
        alignment_metrics *= mask_positive
        max_metrics_per_instance = alignment_metrics.max(
            dim=-1, keepdim=True)[0]
        max_ious_per_instance = (ious * mask_positive).max(
            dim=-1, keepdim=True)[0]
        alignment_metrics = alignment_metrics / (
            max_metrics_per_instance + self.eps) * max_ious_per_instance
        alignment_metrics = alignment_metrics.max(-2)[0].unsqueeze(-1)
        assigned_scores = assigned_scores * alignment_metrics

        return assigned_labels, assigned_bboxes, assigned_scores