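If you have any questions, feel free to leave a comment below.
All code in this article was run in PyCharm.
The accompanying code for this article has been uploaded.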
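DETR algorithm walkthrough
DETR source code walkthrough 1 (project configuration / CocoDetection class / ConvertCocoPolysToMask class)
DETR source code walkthrough 2 (DETR class)
DETR source code walkthrough 3 (positional encoding: Joiner class / PositionEmbeddingSine class)
DETR source code walkthrough 4 (BackboneBase class / Backbone class)
DETR source code walkthrough 5 (Transformer class)
DETR source code walkthrough 6 (encoder: TransformerEncoder class / TransformerEncoderLayer class)
DETR source code walkthrough 7 (decoder: TransformerDecoder class / TransformerDecoderLayer class)
DETR source code walkthrough 8: loss computation (SetCriterion class)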
Location: DETR/models/detr.py, SetCriterion class
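This class computes all of the losses used by DETR. It works in two steps:
1. Compute the Hungarian matching between the ground-truth boxes and the model outputs.
2. Supervise each matched ground-truth/prediction pair (both the class and the box).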
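The losses DETR needs to compute are:
1. The classification loss (80 object categories for COCO, plus one extra "no object" class).
2. The box regression losses: an L1 loss on x, y, w, h, and a GIoU loss.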
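Suppose an image has 5 ground-truth boxes. The model outputs 100 predicted boxes, so 5 of them have to be selected to compute the loss against the labels. Which 5 should be chosen? This is decided by Hungarian matching: a matching cost is computed between each of the 100 predictions and each of the 5 ground-truth boxes, taking into account both whether the predicted class agrees with the label and how close the boxes are. The algorithm then picks the one-to-one assignment of 5 predictions to the 5 ground-truth boxes with the smallest total cost. The remaining 95 predictions are supervised as "no object".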
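To make the matching step concrete, here is a minimal sketch in plain Python. It brute-forces the cheapest one-to-one assignment on a tiny, made-up cost matrix (3 predictions, 2 ground-truth boxes). The cost values and the helper name `min_cost_matching` are assumptions for illustration only; a real matcher would typically rely on an efficient solver such as `scipy.optimize.linear_sum_assignment` instead of enumeration.

```python
from itertools import permutations


def min_cost_matching(cost):
    """Brute-force the cheapest one-to-one assignment (illustration only).

    cost[i][j] is the matching cost between prediction i and ground truth j.
    Returns (pred_indices, gt_indices), ordered by prediction index.
    """
    num_preds, num_gts = len(cost), len(cost[0])
    best_total, best_pairs = float("inf"), None
    # Try every way of giving each ground-truth box its own distinct prediction.
    for preds in permutations(range(num_preds), num_gts):
        total = sum(cost[p][j] for j, p in enumerate(preds))
        if total < best_total:
            best_total = total
            best_pairs = sorted((p, j) for j, p in enumerate(preds))
    pred_idx = [p for p, _ in best_pairs]
    gt_idx = [j for _, j in best_pairs]
    return pred_idx, gt_idx


# Hypothetical costs: rows are 3 predictions, columns are 2 ground-truth boxes.
cost = [
    [0.9, 0.2],
    [0.8, 0.7],
    [0.1, 0.6],
]
pred_idx, gt_idx = min_cost_matching(cost)
print(pred_idx, gt_idx)  # [0, 2] [1, 0]
```

Prediction 0 is paired with ground truth 1 and prediction 2 with ground truth 0; prediction 1 stays unmatched and would be trained towards "no object". The returned pair of index lists has the same layout as one entry of the `indices` argument used throughout SetCriterion.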
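import torch
import torch.nn.functional as F
from torch import nn


class SetCriterion(nn.Module):
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        super().__init__()
        self.num_classes = num_classes  # number of object categories, excluding "no object"
        self.matcher = matcher          # module that computes the Hungarian matching
        self.weight_dict = weight_dict  # loss name -> relative weight
        self.eos_coef = eos_coef        # classification weight of the "no object" class
        self.losses = losses            # names of the losses to apply
        # Per-class weights for the cross-entropy; the last index is "no object"
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)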
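The role of `empty_weight` is easier to see with numbers. The sketch below is plain Python and re-implements a weighted cross-entropy with mean reduction, assuming the usual normalisation (the weighted sum is divided by the sum of the weights of the target classes). The toy logits, the value `eos_coef = 0.1`, and the helper name `weighted_cross_entropy` are assumptions chosen for illustration.

```python
import math


def weighted_cross_entropy(logits, targets, class_weights):
    """Weighted cross-entropy, averaged over the weights of the target classes."""
    weighted_sum, weight_total = 0.0, 0.0
    for row, target in zip(logits, targets):
        log_norm = math.log(sum(math.exp(v) for v in row))
        nll = log_norm - row[target]  # -log softmax(row)[target]
        weighted_sum += class_weights[target] * nll
        weight_total += class_weights[target]
    return weighted_sum / weight_total


num_classes = 2   # hypothetical toy setting
eos_coef = 0.1    # assumed value, for illustration
empty_weight = [1.0] * (num_classes + 1)
empty_weight[-1] = eos_coef  # last index = "no object"

# One query matched to class 0 (still uncertain) and three unmatched queries
# that already lean towards "no object".
logits = [[0.0, 0.0, 0.0]] + [[0.0, 0.0, 2.0]] * 3
targets = [0, 2, 2, 2]

loss_weighted = weighted_cross_entropy(logits, targets, empty_weight)
loss_uniform = weighted_cross_entropy(logits, targets, [1.0, 1.0, 1.0])
print(loss_weighted > loss_uniform)  # True
```

With uniform weights the three easy "no object" queries dilute the average. Down-weighting them lets the single real object dominate the loss, which matters because most queries are unmatched in a typical image.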
Compute the classification loss, i.e. the loss between the predicted classes and the ground-truth labels.
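    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']  # [batch_size, num_queries, num_classes + 1]
        # (batch index, query index) of every matched prediction
        idx = self._get_src_permutation_idx(indices)
        # Ground-truth labels, in the same order as the matched predictions
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Every query starts as "no object" (index num_classes) ...
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        # ... and matched queries receive the label of their ground-truth object
        target_classes[idx] = target_classes_o
        # cross_entropy expects the class dimension in second position
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}
        if log:
            # Logged for monitoring only
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses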
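The way `target_classes` is built can be sketched with plain lists. The example below mirrors the `torch.full` call followed by the indexed assignment. The sizes, labels, and matching result are made up for illustration.

```python
num_classes = 3   # hypothetical: classes 0..2, index 3 means "no object"
num_queries = 4   # hypothetical: far fewer queries than a real model uses

# One (pred_indices, gt_indices) pair per image, in the layout a matcher returns.
indices = [([1, 3], [1, 0]),    # image 0: query 1 -> gt 1, query 3 -> gt 0
           ([0], [0])]          # image 1: query 0 -> gt 0
targets = [{"labels": [2, 0]},  # image 0 has two objects
           {"labels": [1]}]     # image 1 has one object

# Start with every query labelled "no object" ...
target_classes = [[num_classes] * num_queries for _ in targets]
# ... then overwrite the matched queries with the label of their ground truth.
for b, (pred_idx, gt_idx) in enumerate(indices):
    for q, j in zip(pred_idx, gt_idx):
        target_classes[b][q] = targets[b]["labels"][j]

print(target_classes)  # [[3, 0, 3, 2], [1, 3, 3, 3]]
```

Every query therefore receives a classification target, which is why the cross-entropy runs over all queries while the box losses only see the matched ones.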
# Helper used by loss_labels; a module-level function (util/misc.py), not a method of SetCriterion
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk highest-scoring classes for each row
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
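For the default `topk=(1,)`, the helper reduces to a top-1 accuracy in percent, and `class_error` is its complement. A small plain-Python sketch of that arithmetic, with made-up logits and a hypothetical helper name `top1_accuracy`:

```python
def top1_accuracy(logits, labels):
    """Percentage of rows whose highest-scoring class equals the label."""
    if not labels:
        return 0.0
    hits = 0
    for scores, label in zip(logits, labels):
        predicted = max(range(len(scores)), key=lambda c: scores[c])
        hits += int(predicted == label)
    return 100.0 * hits / len(labels)


# Hypothetical logits of four matched queries over 3 classes + "no object".
matched_logits = [
    [2.0, 0.1, 0.3, 0.0],  # predicts class 0
    [0.2, 1.5, 0.1, 0.0],  # predicts class 1
    [0.1, 0.2, 0.3, 2.5],  # predicts "no object"
    [0.0, 0.1, 3.0, 0.2],  # predicts class 2
]
matched_labels = [0, 1, 1, 2]

class_error = 100 - top1_accuracy(matched_logits, matched_labels)
print(class_error)  # 25.0
```

Only the matched queries enter this metric, so a query that wrongly predicts "no object" for a real object counts as an error.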
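The cardinality error measures the absolute difference between the number of boxes the model predicts (queries not classified as "no object") and the number of ground-truth boxes. It is recorded for evaluation and logging only and takes no part in gradient computation or optimization. It shows whether the model tends to predict too many or too few boxes, so although it does not affect training directly, it is a useful metric for analysing and debugging model behaviour.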
    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        # Number of ground-truth objects in each image
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Number of queries whose top-scoring class is not "no object" (the last index)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses
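The same computation in plain Python, as a minimal sketch on made-up logits (2 images, 3 queries, 2 classes plus "no object"). The helper name `cardinality_error` is hypothetical.

```python
def cardinality_error(pred_logits, target_counts):
    """Mean absolute difference between predicted and true object counts.

    pred_logits[b][q] is the list of class scores of query q in image b;
    the last score belongs to the "no object" class.
    """
    errors = []
    for image_logits, n_true in zip(pred_logits, target_counts):
        no_object = len(image_logits[0]) - 1
        n_pred = sum(1 for scores in image_logits
                     if scores.index(max(scores)) != no_object)
        errors.append(abs(n_pred - n_true))
    return sum(errors) / len(errors)


pred_logits = [
    [[2, 0, 0], [0, 0, 3], [0, 1, 0]],  # image 0: two queries predict an object
    [[0, 0, 1], [1, 0, 0], [0, 2, 0]],  # image 1: two queries predict an object
]
target_counts = [2, 1]                  # true number of objects per image

err = cardinality_error(pred_logits, target_counts)
print(err)  # 0.5
```

Image 0 is counted correctly and image 1 is off by one, so the mean absolute error over the batch is 0.5.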
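Compute the box-related losses: the L1 regression loss and the GIoU loss. Each target dict must contain the key "boxes", holding a tensor of shape [nb_target_boxes, 4]. Target boxes are expected in the format (center_x, center_y, w, h), normalized by the image size.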
    def loss_boxes(self, outputs, targets, indices, num_boxes):
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]  # matched predictions only
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        # generalized_box_iou returns a pairwise matrix; its diagonal holds the matched pairs
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses
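A worked example may help with the two box terms. The sketch below is plain Python for a single matched pair; the box values are made up, and the functions are simplified stand-ins for the batched helpers in `box_ops`, assuming the standard GIoU definition (IoU minus the fraction of the enclosing box not covered by the union).

```python
def box_cxcywh_to_xyxy(box):
    """Convert (center_x, center_y, w, h) to (x0, y0, x1, y1)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def generalized_iou(a, b):
    """GIoU of two (x0, y0, x1, y1) boxes with positive area."""
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = area_a + area_b - inter
    # Smallest axis-aligned box that encloses both inputs.
    hull = (max(a[2], b[2]) - min(a[0], b[0])) * (max(a[3], b[3]) - min(a[1], b[1]))
    return inter / union - (hull - union) / hull


pred = (0.5, 0.5, 0.4, 0.4)  # hypothetical prediction, normalized cxcywh
gt = (0.6, 0.5, 0.4, 0.4)    # hypothetical ground truth

l1_loss = sum(abs(p - g) for p, g in zip(pred, gt))
giou_loss = 1 - generalized_iou(box_cxcywh_to_xyxy(pred), box_cxcywh_to_xyxy(gt))
print(round(l1_loss, 6), round(giou_loss, 6))  # 0.1 0.4

# Disjoint boxes: IoU is 0, but GIoU still reflects how far apart they are.
far_loss = 1 - generalized_iou((0.0, 0.0, 1.0, 1.0), (2.0, 0.0, 3.0, 1.0))
print(round(far_loss, 6))  # 1.333333
```

The second case shows why GIoU is used next to the L1 term: plain IoU would be 0 for any pair of non-overlapping boxes, while the GIoU loss keeps growing as the boxes move apart.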
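The loss_masks function is mainly used for the instance segmentation task. It computes the focal loss and the Dice loss, both of which penalize disagreement between the predicted masks and the ground-truth masks. Each target dict must contain the key "masks", holding a tensor of shape [nb_target_boxes, h, w].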
    def loss_masks(self, outputs, targets, indices, num_boxes):
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]  # keep only the matched predictions
        masks = [t["masks"] for t in targets]
        # Pad the target masks to a common size and batch them
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # Upsample the predictions to the target size
        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses
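The Dice term can be illustrated on flattened masks with plain Python. This is a minimal sketch assuming the common smoothed Dice formulation (a `+ 1` in numerator and denominator); the helper body, the toy masks, and the smoothing constant are assumptions, not a copy of the repository's `dice_loss`.

```python
import math


def dice_loss(mask_logits, mask_targets, num_boxes):
    """Smoothed Dice loss over flattened masks, normalized by num_boxes."""
    total = 0.0
    for logits, target in zip(mask_logits, mask_targets):
        probs = [1.0 / (1.0 + math.exp(-x)) for x in logits]  # sigmoid
        numerator = 2.0 * sum(p * t for p, t in zip(probs, target))
        denominator = sum(probs) + sum(target)
        total += 1.0 - (numerator + 1.0) / (denominator + 1.0)
    return total / num_boxes


target = [[1.0, 1.0, 0.0, 0.0]]          # one flattened 2x2 ground-truth mask
good = [[20.0, 20.0, -20.0, -20.0]]      # confident and correct
bad = [[-20.0, -20.0, 20.0, 20.0]]       # confident and inverted

good_loss = dice_loss(good, target, num_boxes=1)
bad_loss = dice_loss(bad, target, num_boxes=1)
print(round(good_loss, 4), round(bad_loss, 4))  # 0.0 0.8
```

The Dice loss depends on the overlap between prediction and target relative to their sizes, so it complements the per-pixel focal loss when masks are small.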
    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx
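These two helpers flatten the per-image matching result into a pair of index lists that can address a whole batch at once. A plain-list sketch of the same idea, with a made-up matching result and string placeholders instead of tensors:

```python
def get_src_permutation_idx(indices):
    """Flatten per-image prediction indices into (batch_idx, src_idx)."""
    batch_idx = [b for b, (src, _) in enumerate(indices) for _ in src]
    src_idx = [q for (src, _) in indices for q in src]
    return batch_idx, src_idx


def get_tgt_permutation_idx(indices):
    """Flatten per-image ground-truth indices into (batch_idx, tgt_idx)."""
    batch_idx = [b for b, (_, tgt) in enumerate(indices) for _ in tgt]
    tgt_idx = [j for (_, tgt) in indices for j in tgt]
    return batch_idx, tgt_idx


# Hypothetical matching: image 0 matches queries 1 and 3, image 1 matches query 0.
indices = [([1, 3], [1, 0]), ([0], [0])]

batch_idx, src_idx = get_src_permutation_idx(indices)
print(batch_idx, src_idx)  # [0, 0, 1] [1, 3, 0]

# Equivalent of the tensor indexing outputs['pred_boxes'][idx]:
pred_boxes = [["b0q0", "b0q1", "b0q2", "b0q3"],
              ["b1q0", "b1q1", "b1q2", "b1q3"]]
matched = [pred_boxes[b][q] for b, q in zip(batch_idx, src_idx)]
print(matched)  # ['b0q1', 'b0q3', 'b1q0']
```

With tensors, indexing by the tuple `(batch_idx, src_idx)` performs this gather in one step and returns one row per matched prediction.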
    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'masks': self.loss_masks
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
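The forward method performs the loss computation. The outputs argument holds the model's predictions, and targets holds the ground-truth information for each sample in the batch. outputs should be a dict of the tensors produced by the model, such as predicted classes, boxes, or masks. targets is a list of dicts, one per sample, each holding that sample's annotations, such as class labels and box coordinates.
Different losses need different keys in targets: the box loss needs the "boxes" key, while the classification loss needs the "labels" key. To compute the losses correctly, targets must contain every key required by the configured losses, and outputs must likewise contain the corresponding predictions.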
    def forward(self, outputs, targets):
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Match the outputs of the last decoder layer to the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Average number of target boxes across all processes, used for normalization
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # Auxiliary losses: repeat matching and loss computation for each entry of aux_outputs
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Mask losses are skipped for the auxiliary outputs
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # class_error is only logged for the final output
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    # Suffix the keys with the index so they do not collide
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)
        return losses
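The dict returned by forward mixes trainable terms with logging-only metrics, and the auxiliary entries carry an index suffix. The sketch below shows one plausible way to turn it into a single scalar, assuming the training loop sums only the entries that have a weight in `weight_dict`. The weight values, loss values, and both helper names are assumptions for illustration.

```python
def expand_weight_dict(weight_dict, num_aux_layers):
    """Add suffixed copies of each weight for the auxiliary outputs."""
    expanded = dict(weight_dict)
    for i in range(num_aux_layers):
        expanded.update({f"{k}_{i}": v for k, v in weight_dict.items()})
    return expanded


def total_loss(loss_dict, weight_dict):
    """Weighted sum over the entries that have a weight; the rest is ignored."""
    return sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)


# Assumed weights; the real values come from the training configuration.
weight_dict = {"loss_ce": 1.0, "loss_bbox": 5.0, "loss_giou": 2.0}
weights = expand_weight_dict(weight_dict, num_aux_layers=1)

# Hypothetical output of SetCriterion.forward with one auxiliary output.
loss_dict = {
    "loss_ce": 0.5, "loss_bbox": 0.1, "loss_giou": 0.4,
    "class_error": 25.0, "cardinality_error": 1.0,       # logging only
    "loss_ce_0": 0.6, "loss_bbox_0": 0.2, "loss_giou_0": 0.5,
}

total = total_loss(loss_dict, weights)
print(round(total, 6))  # 4.4
```

Because `class_error` and `cardinality_error` have no entry in the weight dict, they never reach the optimizer, which is consistent with both being computed without gradients.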