Code Reading: Deformable DETR (Part 5)


class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
        """ Create the criterion.
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            focal_alpha: alpha in Focal Loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha


  1. Compute a bipartite matching between the model outputs and the ground truth (gt);
  2. For each matched pair, supervise its class and box.


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self,
                 cost_class: float = 1,
                 cost_bbox: float = 1,
                 cost_giou: float = 1):
        """Creates the matcher

            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"


    def forward(self, outputs, targets):  # forward pass of the matcher
        with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # We flatten to compute the cost matrices in a batch
            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

            # Also concat the target labels and boxes
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # Compute the classification cost (focal-loss style).
            alpha = 0.25
            gamma = 2.0
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

            # Compute the L1 cost between boxes
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # Compute the giou cost between boxes
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
                                             box_cxcywh_to_xyxy(tgt_bbox))

            # Final cost matrix
            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["boxes"]) for v in targets]  # number of gt objects in each sample of the batch
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]  # for each sample, take its own query-vs-target cost block and run bipartite matching on it
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]  # list of tuples, one per sample (length = batch size)

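The first thing to note is that the whole forward pass of the matcher runs under torch.no_grad(), so no gradient flows back through it. Second, the mismatch between the predicted class and the gt class is measured with a focal-loss-style cost whose parameters are hard-coded (alpha = 0.25, gamma = 2.0). Finally, the Hungarian algorithm is run on each sample of the batch to obtain the bipartite matching. The output has the form [(matched prediction indices of sample 1, matched gt indices of sample 1), ...].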
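The matching step described above can be illustrated on a tiny example. This is only a sketch: it uses plain Python instead of tensors, considers the classification cost alone (no box or giou terms), and replaces linear_sum_assignment with an exhaustive search that is practical only for very small inputs. The names focal_class_cost and brute_force_match are hypothetical helpers, not part of the repository.

```python
import math
from itertools import permutations

ALPHA, GAMMA, EPS = 0.25, 2.0, 1e-8  # same constants as in the matcher's forward


def focal_class_cost(p):
    """Classification cost for one (query, target class) pair, given sigmoid probability p."""
    neg = (1 - ALPHA) * (p ** GAMMA) * (-math.log(1 - p + EPS))
    pos = ALPHA * ((1 - p) ** GAMMA) * (-math.log(p + EPS))
    return pos - neg


def brute_force_match(cost):
    """Return (query_indices, target_indices) with minimal total cost.

    cost[q][t] is the cost of assigning query q to target t.
    Assumes at least as many queries as targets. Exhaustive search,
    used here as a stand-in for the Hungarian algorithm.
    """
    num_queries, num_targets = len(cost), len(cost[0])
    best = min(
        permutations(range(num_queries), num_targets),
        key=lambda qs: sum(cost[q][t] for t, q in enumerate(qs)),
    )
    # Sort the pairs by query index so the output is ordered by query.
    pairs = sorted(zip(best, range(num_targets)))
    return [q for q, _ in pairs], [t for _, t in pairs]


# 3 queries, 2 targets: probs[q][t] is the predicted probability
# of target t's class for query q (made-up numbers).
probs = [[0.9, 0.1], [0.2, 0.3], [0.1, 0.8]]
cost = [[focal_class_cost(p) for p in row] for row in probs]
src, tgt = brute_force_match(cost)
print(src, tgt)  # [0, 2] [0, 1]
```

A confident, correct prediction gets a negative cost and a confident miss gets a positive one, so query 0 is paired with target 0 and query 2 with target 1, while query 1 stays unmatched.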




# per sample, count the queries whose argmax is not the last class index
card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
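To make the card_pred line concrete, here is a minimal plain-Python sketch of the same counting rule on nested lists. The function name count_predicted_objects and the logit values are made up for illustration.

```python
def count_predicted_objects(pred_logits):
    """For each sample, count the queries whose argmax is not the last class index.

    pred_logits is a nested list shaped [batch][query][class].
    """
    counts = []
    for sample in pred_logits:
        last = len(sample[0]) - 1
        n = 0
        for query_logits in sample:
            argmax = max(range(len(query_logits)), key=query_logits.__getitem__)
            if argmax != last:
                n += 1
        counts.append(n)
    return counts


pred_logits = [
    [[2.0, 0.1, -1.0], [0.0, 0.5, 3.0], [-1.0, 1.5, 0.2]],  # argmax: 0, 2, 1
    [[0.1, 0.2, 0.9], [0.3, 0.1, 0.8], [1.2, 0.0, -0.5]],   # argmax: 2, 2, 0
]
card_pred = count_predicted_objects(pred_logits)
print(card_pred)  # [2, 1]
```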


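When the label and box losses are computed, a helper called _get_src_permutation_idx appears. Its job is to flatten the per-sample matching indices returned by the matcher so that they can be used for indexing directly. For example, with batch_size=2 and query_num=4, where the first sample has 2 gt objects and the second has 3, the matcher might return:
[([0, 2], [0, 1]), ([1, 3, 0], [2, 0, 1])]. The idx returned by _get_src_permutation_idx is then a tuple made of [0, 0, 1, 1, 1] (for each matched pair, the index within the batch of the sample that the query belongs to) and [0, 2, 1, 3, 0] (for each matched pair, the index of the query among all queries of its sample).
target_classes[idx] therefore selects the matched queries of each sample, which are then assigned their gt labels.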
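The flattening can be reproduced on the example above with plain Python lists. This is a sketch rather than the repository's tensor code: the gt labels, num_classes = 5, and the use of num_classes as the background index are assumptions made for illustration.

```python
def get_src_permutation_idx(indices):
    """Flatten per-sample (src, tgt) index pairs into (batch_idx, src_idx)."""
    batch_idx = [i for i, (src, _) in enumerate(indices) for _ in src]
    src_idx = [q for src, _ in indices for q in src]
    return batch_idx, src_idx


# Matcher output from the example: one (query indices, gt indices) pair per sample.
indices = [([0, 2], [0, 1]), ([1, 3, 0], [2, 0, 1])]
batch_idx, src_idx = get_src_permutation_idx(indices)
print(batch_idx)  # [0, 0, 1, 1, 1]
print(src_idx)    # [0, 2, 1, 3, 0]

# Hypothetical gt labels per sample, and a background index of num_classes.
num_classes, num_queries = 5, 4
gt_labels = [[3, 1], [0, 4, 2]]

# Labels of the matched gt objects, in the same flattened order as idx.
matched_labels = [gt_labels[i][t] for i, (_, tgt) in enumerate(indices) for t in tgt]

# Start with every query marked as background, then overwrite the matched ones.
target_classes = [[num_classes] * num_queries for _ in indices]
for b, q, label in zip(batch_idx, src_idx, matched_labels):
    target_classes[b][q] = label
print(target_classes)  # [[3, 5, 1, 5], [4, 2, 5, 0]]
```

The loop at the end plays the role of the single indexed assignment target_classes[idx] = ... in the tensor version.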

        target_classes_onehot = target_classes_onehot[:,:,:-1]  # drop the last column, which is the background class
        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]

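Dropping the last column means that a query whose gt is background gets an all-zero target vector, so the focal loss is built on sigmoid + F.binary_cross_entropy_with_logits rather than on softmax.
One odd detail is that loss_ce is multiplied by query_num (src_logits.shape[1]). This is because sigmoid_focal_loss takes a mean over the query dimension before returning, so the factor cancels that mean.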
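The cancellation can be checked numerically. The sketch below assumes that the reduction inside sigmoid_focal_loss is a mean over the query dimension followed by a sum and a division by num_boxes. Nested lists stand in for a [batch, query, class] tensor, and the loss values are made up.

```python
import math


def mean_over_queries_then_sum(loss, num_boxes):
    """Assumed reduction: mean over the query dimension, sum the rest, divide by num_boxes."""
    total = 0.0
    for sample in loss:
        num_queries = len(sample)
        for c in range(len(sample[0])):
            total += sum(sample[q][c] for q in range(num_queries)) / num_queries
    return total / num_boxes


def plain_sum(loss, num_boxes):
    """Sum over every element, divided by num_boxes."""
    return sum(v for sample in loss for query in sample for v in query) / num_boxes


# 1 sample, 4 queries, 2 classes of element-wise loss values.
loss = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]]
num_boxes, num_queries = 3, 4

reduced = mean_over_queries_then_sum(loss, num_boxes)
loss_ce = reduced * num_queries  # the extra factor applied in the criterion
print(reduced, loss_ce, plain_sum(loss, num_boxes))  # 3.0 12.0 12.0
```

Under that assumption, multiplying by the number of queries turns the result back into a plain sum over all queries and classes, normalized by num_boxes.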



