Transformer实战-系列教程21:DETR 源码解读8 损失计算:(SetCriterion类)

Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8 损失计算:(SetCriterion类)

14、SetCriterion类

位置:DETR/models/detr.py/SetCriterion类

这个类专门用来计算DETR的各种损失,主要包含两步:
1、计算真实边界框和模型输出之间的匈牙利匹配
2、对每一对匹配的真实目标/预测(监督类别和边界框)进行监督

DETR需要计算的损失包含:
1、分类损失,80类别
2、回归框的损失,x、y、w、h,以及GIU

假如标签有5个框,100个框里面就需要挑选出5个来进行计算损失,可是100个里面应该挑选那个呢?使用匈牙利匹配算法,首先将100个框中首先预测的物体要跟标签的5个框中的预测物体是对应的那些框找出来,然后找出对应损失最小的5个框,这就是匈牙利算法

14.1 构造函数

class SetCriterion(nn.Module):
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)
  1. 定义类,继承nn.Module
  2. 构造函数,传入5个参数
  3. 初始化
  4. num_classes,目标类别的数量,不包括特殊的“无对象”类别
  5. atcher,一个模块,能够计算目标和模型输出之间的匹配
  6. weight_dict,一个字典,包含各种损失名称及其相对权重
  7. eos_coef,相对分类权重,应用于“无对象”类别
  8. losses,一个列表,包含要应用的所有损失的名称
  9. empty_weight ,创建一个长度为num_classes + 1的张量empty_weight,并将所有元素初始化为1。这个张量用于调整每个类别的权重,包括特殊的无对象(end-of-sequence, EOS)类别
  10. empty_weight[-1], 将empty_weight张量中最后一个元素(对应于无对象类别)的值设置为eos_coef。这样做是为了在损失计算中对无对象类别给予不同的权重,通常这个权重较小,因为无对象类别通常比其他对象类别更频繁地出现
  11. register_buffer,使用register_buffer方法注册empty_weight张量为一个缓冲区。在PyTorch中,缓冲区是模块的一部分,其内容在模型保存和加载时会被保存和加载,但它们不是模型参数,因此在训练过程中不会被优化器更新。这对于存储不需要训练的数据(如这里的类别权重)非常有用

14.2 分类损失------loss_labels()

14.2.1 loss_labels函数:

计算分类损失,即预测类别和真实标签类别的损失

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}
        if log:
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses
  1. loss_labels函数,接受5个参数,模型输出outputs、真实目标targets、匈牙利匹配的结果indices、图像中边界框的总数num_boxes,以及一个标志log,用于决定是否记录额外的分类错误信息
  2. 确保outputs字典中包含pred_logits键,该键对应的值包含模型预测的分类逻辑值
  3. src_logits ,从outputs中提取预测的分类索引
  4. idx ,调用_get_src_permutation_idx方法,根据匈牙利匹配结果indices计算预测和目标真实之间的排列索引
  5. target_classes_o ,根据匹配结果,从每个目标中选取匹配的标签,然后将这些标签连接成一个一维张量target_classes_o
  6. target_classes ,创建一个填充值为self.num_classes(表示“无对象”类别)的张量target_classes,其形状与预测逻辑值src_logits的前两维相同,这个张量用于存放每个预测位置的目标类别
  7. target_classes[idx],使用计算得到的索引idx将匹配的真实类别target_classes_o填充到target_classes张量中的相应位置
  8. loss_ce ,计算交叉熵损失。首先,将src_logits的维度进行转置以匹配cross_entropy函数的期望输入形状,然后使用target_classes作为真实类别标签,self.empty_weight应用于类别权重
  9. losses,创建一个记录分类损失的字典
  10. 如果log标志为True,则计算并记录额外的分类错误信息
  11. losses[‘class_error’],计算分类准确率,并将其转换为分类错误率,添加到losses字典中。accuracy函数计算匹配的预测和真实标签之间的准确率,然后从100%中减去这个值得到错误率

14.2.2 accuracy函数:

@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

14.3 基数误差------loss_cardinality()

基数误差主要用于评估和记录模型预测的边界框数量与真实边界框数量之间的绝对误差,而不用于训练过程中的梯度计算或模型优化。这个指标可以帮助理解模型在预测边界框数量方面的准确性,尤其是在它可能预测出过多或过少边界框的情况下。尽管它不直接影响模型训练,但对于模型性能的分析调试来说是一个有用的指标。

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses
  1. 这个装饰器表示在执行此函数时不追踪梯度
  2. loss_cardinality函数,接收4个参数:outputs是模型的输出字典,targets是真实目标的列表,indices是匈牙利匹配算法计算得到的匹配索引,num_boxes是每个图像中目标框的数量
  3. pred_logits ,从模型输出中提取预测的类别逻辑值(logits)
  4. device ,获取pred_logits所在的设备(例如CPU或GPU),以确保后续操作在同一设备上进行,避免不必要的数据传输
  5. tgt_lengths ,计算每个目标中标签的数量,并将这些数量创建为一个tensor:tgt_lengths,该tensor在与pred_logits相同的设备上。这个tensor代表了每个图像中真实目标框的数量
  6. card_pred ,计算预测的有分类对象边界框数量。首先,使用argmax(-1)找到每个预测的最可能的类别索引;然后,通过比较这些索引是否不等于最后一个类别(最后一个类别为“无对象”类别),得到一个布尔tensor,其中True表示预测为非空边界框;最后,对布尔张量沿着第一个维度求和,得到每个图像中预测的非空边界框的数量
  7. card_err ,使用L1损失函数计算预测的非空边界框数量card_pred与真实目标框数量tgt_lengths之间的绝对误差,这里将card_pred和tgt_lengths转换为浮点数,因为L1损失函数要求输入为浮点数
  8. losses ,包含了计算得到的基数误差card_err,键为"cardinality_error"的字典
  9. return

14.4 边界框损失------loss_boxes()

计算与边界框相关的损失,包括L1回归损失和GIoU损失,标签是一个字典包含键“boxes”,其包含一个维度为[nb_target_boxes, 4]的张量,目标边界框的格式预期为(center_x, center_y, w, h)

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses
  1. loss_boxes函数,传入4个参数,模型输出outputs、真实边界框标签targets、匈牙利匹配的结果indices、图像中边界框的总数
  2. 确保outputs输出字典中包含’pred_boxes’键
  3. idx ,调用_get_src_permutation_idx方法,根据匈牙利匹配算法得到的indices(匹配索引)计算源(预测)和目标(真实)之间的排列索引idx。这些索引用于从预测和目标中选择匹配的项
  4. src_boxes ,使用排列索引idx从模型输出的预测边界框中选择匹配的预测边界框src_boxes
  5. target_boxes ,通过遍历每个目标字典t和对应的匹配索引(_, i),从目标中选择匹配的边界框,然后使用torch.cat将这些边界框连接成一个连续的张量target_boxes。
  6. loss_bbox ,计算预测边界框src_boxes和目标边界框target_boxes之间的L1损失。reduction='none’参数表示不对损失进行求和或平均,保持损失的原始形状
  7. losses ,用于存储计算得到的损失
  8. losses,将L1损失的总和除以目标框的总数num_boxes,计算平均L1损失,并将其添加到losses字典中
  9. loss_giou ,计算广义交并比(GIoU)损失。首先,将预测边界框和目标边界框从中心坐标格式转换为角点坐标格式,然后计算它们之间的GIoU,GIoU值在0到1之间,1表示完美重合,因此使用1 - GIoU计算损失
  10. losses[‘loss_giou’],将GIoU损失的总和除以目标框的总数num_boxes,计算平均GIoU损失,并将其添加到losses字典中
  11. 返回字典

14.5 掩码损失------loss_masks()

loss_masks函数主要用于实例分割任务,焦点损失(Focal Loss)和Dice损失,针对模型预测的掩码和真实目标掩码之间的不一致,target为标签,字典含键“masks”,其包含一个维度为[nb_target_boxes, h, w]的张量

    def loss_masks(self, outputs, targets, indices, num_boxes):
        assert "pred_masks" in outputs
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]
        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)
        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses
  1. loss_masks函数,接受模型输出outputs、 targets掩码标签、匈牙利匹配算法计算得到的匹配索引indices、每个图像中目标框的数量num_boxes
  2. 确保模型输出中包含了预测掩码"pred_masks"
  3. src_idx ,使用匈牙利匹配算法得到的索引indices来计算预测的排列索引
  4. tgt_idx ,使用匈牙利匹配算法得到的索引indices来计算目标的排列索引
  5. src_masks ,从模型输出中提取预测掩码
  6. src_masks ,使用源排列索引src_idx选择匹配的预测掩码
  7. masks ,从每个目标字典中提取真实掩码,形成一个列表masks
  8. target_masks, valid,将真实掩码列表转换为嵌套张量,并分解为掩码张量和有效性布尔张量
  9. target_masks ,目标掩码张量转移到与预测掩码相同的设备上
  10. target_masks ,使用目标排列索引tgt_idx选择匹配的目标掩码
  11. src_masks ,使用双线性插值将预测掩码上采样到目标掩码的尺寸。这是为了确保预测掩码和目标掩码具有相同的空间维度,从而可以计算损失
  12. src_masks ,移除插值结果中多余的维度,并将预测掩码展平,准备进行损失计算
  13. target_masks ,将目标掩码也展平
  14. target_masks ,调整其形状以匹配预测掩码的形状
  15. losses ,计算焦点损失和Dice损失
  16. 返回字典

14.6 辅助函数

14.6.1 获取源排列索引 _get_src_permutation_idx

    def _get_src_permutation_idx(self, indices):
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx
  1. _get_src_permutation_idx方法用于生成源(预测)索引的排列,以便将预测结果与匹配的目标对齐
  2. batch_idx是一个tensor,包含每个预测索引所在批次的索引,是通过遍历indices(每个元素是一个(src, tgt)索引对)并对每个src索引填充对应的批次号i来创建的
  3. src_idx是一个tensor,直接从indices中提取所有的src索引并将它们连接起来
  4. 函数返回两个tensor:batch_idx和src_idx,它们一起用于从预测tensor中选择与目标匹配的元素

14.6.2 获取目标排列索引 _get_tgt_permutation_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx
  1. 这个方法与_get_src_permutation_idx非常相似,但用于生成目标(真实)索引的排列
  2. 与源索引生成方法类似,batch_idx表示每个目标索引所在的批次索引
  3. tgt_idx则是从indices中提取的所有目标索引
  4. 返回的batch_idx和tgt_idx用于从目标数据中选择与预测匹配的元素。

14.6.3 计算损失 get_loss

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'masks': self.loss_masks
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
  1. loss_map是一个字典,将损失名称映射到相应的损失计算方法
  2. 方法首先检查请求的损失类型是否存在于loss_map中,如果不存在,将引发断言错误
  3. 如果损失类型有效,将调用对应的损失计算方法,并传入必要的参数

14.7 前向传播

forward方法的主要职责是执行模型的损失计算,其中outputs参数包含了模型的预测输出,而targets参数包含了每个batch中的真实目标信息。outputs应该是一个字典,其中包含模型输出的各种张量,例如预测的类别、边界框或掩码等。targets是一个包含多个字典的列表,每个字典代表一个样本的真实标注信息,如目标类别标签和边界框坐标等。

不同的损失计算可能需要targets中的不同键,例如,边界框损失需要"boxes"键,而分类损失则需要"labels"键。这意味着,为了正确计算损失,需要确保targets中包含了所有必需的信息。此外,根据不同损失的需求,outputs中也应包含相应的预测信息

    def forward(self, outputs, targets):
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
        indices = self.matcher(outputs_without_aux, targets)
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)
        return losses
  1. forward函数,传入2个参数
  2. outputs_without_aux ,从模型输出中排除辅助输出(‘aux_outputs’),只保留最后一层的输出用于匹配和损失计算
  3. indices ,matcher是matcher.py中自定义类的实例化,根据最后一层的输出和真实目标计算匹配索引indices
  4. num_boxes ,计算所有目标中标签的总数,即批次中目标框的总数量
  5. num_boxes ,将目标框总数转换为tensor,并放置在模型输出所在的设备上
  6. 如果使用了分布式训练:
  7. 通过all_reduce操作聚合所有进程的目标框总数
  8. num_boxes ,将目标框总数除以 分布式训练中的总进程数,并确保至少为1,以避免除以0的情况
  9. losses ,用于存储计算出的所有损失
  10. 遍历所有需要计算的损失类型
  11. 调用get_loss方法计算每种损失,并更新到losses字典中
  12. 如果模型输出中包含辅助输出:
  13. 对每个辅助输出重复损失计算过程
  14. 遍历每个辅助输出
  15. indices ,对每个辅助输出使用matcher计算匹配索引
  16. 遍历所有需要计算的损失类型
  17. 损失类型为掩码损失:
  18. 由于计算成本过高,选择跳过不计算
  19. 对于标签损失,设置额外参数以禁用日志记录,这通常只在最后一层启用
  20. l_dict ,调用get_loss方法计算当前循环中给定损失类型的损失
  21. l_dict ,为辅助输出的损失键添加后缀,以区分不同层的损失
  22. losses,将计算出的损失更新到总损失字典中
  23. 返回损失字典

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8 损失计算:(SetCriterion类)

你可能感兴趣的:(Transformer实战,transformer,深度学习,人工智能,计算机视觉,DETR,物体检测)