YOLOX Object Detection Algorithm (Code Walkthrough)

Table of Contents

  • Preface
  • Part 1: Testing
    • 1. demo.py line 196: build a Predictor and feed it the test image
      • 1.1 The Predictor's inference process
        • 1.1.1 Inside self.model (yolox.py)
        • 1.1.2 Post-processing: postprocess (after outputs = self.model(img))
    • 2. Visualization and saving results
      • 2.1 Visualization expanded (demo.py line 170)
      • 2.2 The visualization function (yolox.utils/vis)
  • Part 2: Training
    • 3. self.head (get_losses)
      • 3.1 self.get_assignments
        • 3.1.1 self.get_in_boxes_info
        • 3.1.2 self.dynamic_k_matching
    • 4. Backpropagation


Preface

YOLOX is clean and efficient. This post walks through the concrete implementation; much of the code can be reused elsewhere and is well worth studying.


Part 1: Testing

Testing is straightforward; start with demo.py.

Running it requires three arguments:
--path: path to the test image
--exp_file: the model config file to use, e.g. default/yolox_m.py
--ckpt: the pretrained weights, e.g. yolox_m.pth
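
An example invocation might look like the following (the positional image mode and the --save_result flag follow demo.py's usual interface; adjust the paths to your local setup):

python tools/demo.py image --exp_file exps/default/yolox_m.py --ckpt yolox_m.pth --path assets/dog.jpg --save_result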

1. demo.py line 196: build a Predictor and feed it the test image

outputs, img_info = predictor.inference(image_name)                       # outputs[0]: (14, 7) = x1, y1, x2, y2, obj_conf, class_conf, class_pred
result_image = predictor.visual(outputs[0], img_info, predictor.confthre)

1.1 Then enter the Predictor's inference process:

img = cv2.imread(img)
ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
# scale factor used to resize the original image to fit 640*640

img, _ = self.preproc(img, None, self.test_size)  # to (3, 640, 640), CHW
img = torch.from_numpy(img).unsqueeze(0).float()  # to a (1, 3, 640, 640) tensor (this conversion is in the original demo, omitted in the excerpt)
with torch.no_grad():
     outputs = self.model(img)                    # ([1, 8400, 85]); 8400 = 80*80 + 40*40 + 20*20, 85 = 80 (classes) + 4 (box) + 1 (objectness)
outputs = postprocess(
            outputs, self.num_classes, self.confthre,
            self.nmsthre, class_agnostic=True
        )
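
self.preproc keeps the aspect ratio and pads to test_size. A minimal sketch of what such a preprocessing step does (a hand-written approximation for illustration, not the exact library code):

import cv2
import numpy as np

def simple_preproc(img, input_size=(640, 640), swap=(2, 0, 1)):
    # gray (value 114) canvas of the network input size
    padded = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
    # scale so the image fits into input_size while keeping the aspect ratio
    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized = cv2.resize(
        img, (int(img.shape[1] * r), int(img.shape[0] * r)),
        interpolation=cv2.INTER_LINEAR,
    )
    padded[: resized.shape[0], : resized.shape[1]] = resized
    # HWC -> CHW, float32, contiguous memory for the tensor conversion
    padded = np.ascontiguousarray(padded.transpose(swap), dtype=np.float32)
    return padded, r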

1.1.1 Inside self.model (yolox.py)

fpn_outs = self.backbone(x)
# (128, 80, 80) (256, 40, 40) (512, 20, 20): the three downsampled feature maps
outputs = self.head(fpn_outs)
  1.1.1.1 Expanding self.head (models/yolo_head.py)
for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
       zip(self.cls_convs, self.reg_convs, self.strides, xin)
  ):                                                     # loops 3 times, classifying and regressing one feature map per iteration
            x = self.stems[k](x)                         # project the feature map to 128 channels, e.g. map 1: (1, 128, 80, 80)
            cls_x = x
            reg_x = x

            cls_feat = cls_conv(cls_x)                   # decoupled head: two consecutive conv(128, 128, 3, 1) + BN + SiLU
            cls_output = self.cls_preds[k](cls_feat)     # Conv2d(128, 20), classification

            reg_feat = reg_conv(reg_x)                   # decoupled head, same structure as above
            reg_output = self.reg_preds[k](reg_feat)     # Conv2d(128, 4), box regression
            obj_output = self.obj_preds[k](reg_feat)     # Conv2d(128, 1), objectness


            output = torch.cat(
                [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
            )                                            # (1,25,80,80)

            outputs.append(output)                       # (1,25,80,80) (1,25,40,40) (1,25,20,20)
self.hw = [x.shape[-2:] for x in outputs]                # torch.Size(80, 80)(40, 40), (20, 20)
outputs = torch.cat(
     [x.flatten(start_dim=2) for x in outputs], dim=2
    ).permute(0, 2, 1)                                   # ([1, 8400, 25])
if self.decode_in_inference:                             # True
     return self.decode_outputs(outputs, dtype=xin[0].type())
else:
     return outputs
  1.1.1.1.1 The decode_outputs function (yolo_head.py)
    This function decodes the raw predictions. YOLOX predicts (x, y, w, h) as offsets: although the model is anchor-free, it still uses a reference grid, namely the points uniformly sampled on each feature map, with the downsampling stride playing the role of the anchor size. Decoding the offsets against this grid gives the final output.
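    For example, if the grid point (10, 20) on the stride-8 map predicts (x, y, w, h) = (0.3, -0.2, 0.5, 0.1), the decoded center is x = (0.3 + 10) * 8 = 82.4, y = (-0.2 + 20) * 8 = 158.4 and the size is w = exp(0.5) * 8 ≈ 13.2, h = exp(0.1) * 8 ≈ 8.8 (illustrative numbers, not from a real run).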
def decode_outputs(self, outputs, dtype):
    grids = []
    strides = []
    for (hsize, wsize), stride in zip(self.hw, self.strides):   # 80, 40, 20, with strides [8, 16, 32]
        yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
        # e.g. for the (80, 80) feature map, generate two (80, 80) coordinate grids

        grid = torch.stack((xv, yv), 2).view(1, -1, 2)          # ([1, 6400, 2])
        grids.append(grid)
        shape = grid.shape[:2]                                  # ([1, 6400])
        strides.append(torch.full((*shape, 1), stride))         # (1,6400,1)*[8] (1,1600,1)*[16] (1,400,1)*[32]  

    grids = torch.cat(grids, dim=1).type(dtype)
    strides = torch.cat(strides, dim=1).type(dtype)

    outputs[..., :2] = (outputs[..., :2] + grids) * strides     # (predicted x, y offset + grid-point coordinate) * stride
    outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides  # exp(predicted w, h) * stride
    return outputs                                              # ([1, 8400, 85]); 8400 = 80*80 + 40*40 + 20*20, 85 = 80 + 4 + 1

1.1.2 Post-processing: postprocess (after outputs = self.model(img))

outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre, class_agnostic=True):



def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
    box_corner = prediction.new(prediction.shape)
    ## convert boxes to corner coordinates: x1 y1 x2 y2
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):                                                    # image_pred:(8400, 85)

        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)     # highest class score and its index; obj_conf * class_conf is thresholded below (e.g. at 0.3)
 
        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)             # (8400, 7)
        detections = detections[conf_mask]                                                         # (93, 7) after filtering with the confidence threshold

        if class_agnostic:
            nms_out_index = torchvision.ops.nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                nms_thre,
            )                                                                                # NMS (using scores and boxes): returns the indices of the kept detections
        else:
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                detections[:, 6],
                nms_thre,
            )                                                                               # not executed here (class_agnostic=True)

        detections = detections[nms_out_index]                                              # (14, 7)
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))

    return output

2. Visualization and saving results

outputs, img_info = predictor.inference(image_name)
result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
if save_result:
    save_folder = os.path.join(
        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    os.makedirs(save_folder, exist_ok=True)
    save_file_name = os.path.join(save_folder, os.path.basename(image_name))
    logger.info("Saving detection result in {}".format(save_file_name))
    cv2.imwrite(save_file_name, result_image)
    ch = cv2.waitKey(0)

2.1 Visualization expanded (demo.py line 170)

    def visual(self, output, img_info, cls_conf=0.35):
        ratio = img_info["ratio"]              # 缩放比例:0.45
        img = img_info["raw_img"]              # (1050, 1400, 3)
        if output is None:
            return img
        output = output.cpu()

        bboxes = output[:, 0:4]

        # preprocessing: resize
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res
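
For example, with ratio = 0.45 a predicted x1 = 100 on the 640*640 network input maps back to 100 / 0.45 ≈ 222 on the original image (illustrative numbers).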

2.2 The visualization function (yolox.utils/vis)

def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):

    for i in range(len(boxes)):
        box = boxes[i]
        cls_id = int(cls_ids[i])
        score = scores[i]
        if score < conf:
            continue
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])

        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
        text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
        font = cv2.FONT_HERSHEY_SIMPLEX

        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)

        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
        cv2.rectangle(
            img,
            (x0, y0 + 1),
            (x0 + txt_size[0] + 1, y0 + int(1.5*txt_size[1])),
            txt_bk_color,
            -1
        )
        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)

    return img

Part 2: Training

Data layout for training: datasets/VOCdevkit/VOC2007/ contains three folders: JPEGImages (the jpg images), Annotations (the matching xml annotations) and ImageSets (the split lists); a typical layout is sketched below.
Training enters trainer.train() from train.py line 110.
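
Assumed layout (the split-file names under ImageSets/Main depend on the exp config; shown here only as an illustration):

datasets/VOCdevkit/VOC2007/
├── JPEGImages/          # *.jpg images
├── Annotations/         # matching *.xml annotations
└── ImageSets/
    └── Main/            # train.txt / val.txt / trainval.txt split lists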

yolox.py line 30:

fpn_outs = self.backbone(x)

      if self.training:
          assert targets is not None
          loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
              fpn_outs, targets, x
          )
          outputs = {
              "total_loss": loss,
              "iou_loss": iou_loss,
              "l1_loss": l1_loss,
              "conf_loss": conf_loss,
              "cls_loss": cls_loss,
              "num_fg": num_fg,
          }
      else:
          outputs = self.head(fpn_outs)                 # inference branch: returns the decoded predictions

      return outputs

3. self.head (get_losses)

The core function is self.get_assignments, which assigns positive labels; it is analyzed in detail below,
together with self.dynamic_k_matching, which dynamically picks k positive samples per ground truth.

class YOLOXHead(nn.Module):
def get_losses(self,imgs, x_shifts, y_shifts,  expanded_strides, labels, outputs,
      origin_preds, dtype):
      bbox_preds = outputs[:, :, :4]                    # [bs, n_anchors, 4]:([8, 8400, 4])
      obj_preds = outputs[:, :, 4].unsqueeze(-1)        # ([8, 8400, 1])
      cls_preds = outputs[:, :, 5:]                     # ([8, 8400, 20])

      # calculate targets
      nlabel = (labels.sum(dim=2) > 0).sum(dim=1)       # number of gts per image (labels are zero-padded), e.g. [5, 6, 21, 2, 5, 2, 2, 6]

      total_num_anchors = outputs.shape[1]              # 8400
      x_shifts = torch.cat(x_shifts, 1)                 # [1, n_anchors_all]; x_shifts[0]: (1, 6400), x_shifts[1]: (1, 1600), x_shifts[2]: (1, 400); values run 0..width-1 row by row
      y_shifts = torch.cat(y_shifts, 1)                 # [1, n_anchors_all]: ([1, 8400])
      expanded_strides = torch.cat(expanded_strides, 1) # (1, 8400): 6400*[8, 8, ...], 1600*[16, 16, ...], 400*[32, 32, ...]
      if self.use_l1:
          origin_preds = torch.cat(origin_preds, 1)

      cls_targets = []
      reg_targets = []
      l1_targets = []
      obj_targets = []
      fg_masks = []

      num_fg = 0.0
      num_gts = 0.0

      for batch_idx in range(outputs.shape[0]):                       # batchsize
          num_gt = int(nlabel[batch_idx])
          num_gts += num_gt                                           # 5
          if num_gt == 0:
              cls_target = outputs.new_zeros((0, self.num_classes))
              reg_target = outputs.new_zeros((0, 4))
              l1_target = outputs.new_zeros((0, 4))
              obj_target = outputs.new_zeros((total_num_anchors, 1))
              fg_mask = outputs.new_zeros(total_num_anchors).bool()
          else:
              gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]   # (8,4)
              gt_classes = labels[batch_idx, :num_gt, 0]              # (8) gt_num
              bboxes_preds_per_image = bbox_preds[batch_idx]          # (8400,4)

              try:
                  ( gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img,
                  ) = self.get_assignments( batch_idx, num_gt, total_num_anchors,
                        gt_bboxes_per_image, gt_classes, bboxes_preds_per_image,
                        expanded_strides, x_shifts, y_shifts, cls_preds, bbox_preds,
                        obj_preds, labels, imgs)
                  # the call above assigns positive/negative samples; its return values are analyzed in section 3.1 (self.get_assignments)
              except RuntimeError:
                  # the original code retries the assignment in CPU mode on CUDA OOM; the fallback is omitted here
                  raise

              torch.cuda.empty_cache()
              num_fg += num_fg_img                                        # e.g. 34 positives for this image

              cls_target = F.one_hot(
                  gt_matched_classes.to(torch.int64), self.num_classes
              ) * pred_ious_this_matching.unsqueeze(-1)                   # (34) --> (34, 20), one-hot scaled by the matched IoU
              obj_target = fg_mask.unsqueeze(-1)                          # (8400, 1), 34 entries are True
              reg_target = gt_bboxes_per_image[matched_gt_inds]           # (34, 4)

          cls_targets.append(cls_target)
          reg_targets.append(reg_target)
          obj_targets.append(obj_target.to(dtype))
          fg_masks.append(fg_mask)
          if self.use_l1:                                                 # False
              l1_targets.append(l1_target)

      cls_targets = torch.cat(cls_targets, 0)                # (385, 20)
      reg_targets = torch.cat(reg_targets, 0)                # (385, 4)
      obj_targets = torch.cat(obj_targets, 0)                # (67200, 1)    8400 * 8 = 67200
      fg_masks = torch.cat(fg_masks, 0)                      # (67200)
      if self.use_l1:
          l1_targets = torch.cat(l1_targets, 0)

      num_fg = max(num_fg, 1)
      loss_iou = (
          self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
      ).sum() / num_fg
      loss_obj = (
          self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
      ).sum() / num_fg
      loss_cls = (
          self.bcewithlog_loss(
              cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
          )
      ).sum() / num_fg
      if self.use_l1:
          loss_l1 = (
              self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)
          ).sum() / num_fg
      else:
          loss_l1 = 0.0

      reg_weight = 5.0
      loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1

      return (
          loss,
          reg_weight * loss_iou,
          loss_obj,
          loss_cls,
          loss_l1,
          num_fg / max(num_gts, 1),
      )
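
In summary, the total training loss is loss = 5 * loss_iou + loss_obj + loss_cls + loss_l1, with every term averaged over the number of positive samples (num_fg). loss_obj is computed over all 8400 points of the batch, while the IoU, class and L1 terms only cover the positives selected by fg_masks, and loss_l1 is only active when use_l1 is True.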

3.1 self.get_assignments

This function spreads the ground-truth boxes over the three feature maps (8400 points in total) and splits the candidate points into positive and negative samples.

def get_assignments( self, batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image,
        gt_classes, bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
        cls_preds,  bbox_preds,  obj_preds,  labels,  imgs,  mode="gpu"):

      fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
          gt_bboxes_per_image, expanded_strides, x_shifts,
          y_shifts, total_num_anchors, num_gt)                            # fg_mask: (8400), 3473 True; is_in_boxes_and_center: (5, 3473), 325 True

      bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]           # ([8400, 4]) ---> ([3473, 4])
      cls_preds_ = cls_preds[batch_idx][fg_mask]                         # ([3473, 20])
      obj_preds_ = obj_preds[batch_idx][fg_mask]                         # ([3473, 1])
      num_in_boxes_anchor = bboxes_preds_per_image.shape[0]              # 3473

      pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)   # (5, 4) & (3473, 4) --> (5, 3473)

      gt_cls_per_image = (
         F.one_hot(gt_classes.to(torch.int64), self.num_classes)
         .float() .unsqueeze(1)  .repeat(1, num_in_boxes_anchor, 1))          # (5,) --> (5, 20) --> (5, 3473, 20)
      pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)                # (5, 3473)

      with torch.cuda.amp.autocast(enabled=False):
          cls_preds_ = (
              cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
              * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
          )                                                             # ( 3473, 20 ) --> sigmoid --> ( 5, 3473, 20 )
          pair_wise_cls_loss = F.binary_cross_entropy(
              cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
            ).sum(-1)                                                   # ( 5, 3473, 20 ) & ( 5, 3473, 20 )  ---> ( 5,3473 )
      del cls_preds_

      cost = (
          pair_wise_cls_loss
          + 3.0 * pair_wise_ious_loss
          + 100000.0 * (~is_in_boxes_and_center)
      )                                                                 #  ( 5, 3473 )

      (
          num_fg,
          gt_matched_classes,
          pred_ious_this_matching,
          matched_gt_inds,
      ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
      del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

      return (
          gt_matched_classes,                             # (34): classes of the 34 positive samples
          fg_mask,                                        # (8400), 34 entries are True
          pred_ious_this_matching,                        # (34): IoU of each positive with its matched gt
          matched_gt_inds,                                # (34): index of the gt each positive is matched to
          num_fg,
      )
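
To summarize the matching cost built above: for ground truth i and candidate anchor j, cost[i][j] = pair_wise_cls_loss[i][j] + 3.0 * pair_wise_ious_loss[i][j] + 100000.0 * (anchor j is not inside both the gt box and its center region), so anchors that fail the box-and-center test are effectively excluded from matching.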

3.1.1 self.get_in_boxes_info

A first-pass filter over the 8400 predicted points.
For each anchor-point center, compute its offsets to the gt box's left/top/right/bottom edges (b_l, b_t, b_r, b_b) and keep the points whose offsets are all positive, i.e. centers lying inside the gt box; the same check is done against a fixed-radius center region (c_l, c_t, c_r, c_b). A toy illustration follows the code below.

    def get_in_boxes_info(
        self, gt_bboxes_per_image, expanded_strides,  x_shifts,
        y_shifts, total_num_anchors, num_gt):
        expanded_strides_per_image = expanded_strides[0]                      # (8400)
        x_shifts_per_image = x_shifts[0] * expanded_strides_per_image         # (8400)     [0,1,2...79,...0,1,2,...39,0,1,2,...19]*stride
        y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
        x_centers_per_image = (
            (x_shifts_per_image + 0.5 * expanded_strides_per_image)
            .unsqueeze(0)
            .repeat(num_gt, 1)                           # (5, 8400): the 8400 anchor-point center coordinates (absolute values on the 640*640 image)
        )  # [n_anchor] -> [n_gt, n_anchor]
        y_centers_per_image = (
            (y_shifts_per_image + 0.5 * expanded_strides_per_image)
            .unsqueeze(0)
            .repeat(num_gt, 1)
        )

        gt_bboxes_per_image_l = (
            (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
            .unsqueeze(1)
            .repeat(1, total_num_anchors)
        )                                                # ([5, 8400])   x1
        gt_bboxes_per_image_r = (
            (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
            .unsqueeze(1)
            .repeat(1, total_num_anchors)
        )                                                # ([5, 8400])   x2
        gt_bboxes_per_image_t = (
            (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
            .unsqueeze(1)
            .repeat(1, total_num_anchors)
        )                                                # ([5, 8400])   y1
        gt_bboxes_per_image_b = (
            (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
            .unsqueeze(1)
            .repeat(1, total_num_anchors)
        )                                                # ([5, 8400])   y2

        b_l = x_centers_per_image - gt_bboxes_per_image_l                             #  ([5, 8400])
        b_r = gt_bboxes_per_image_r - x_centers_per_image
        b_t = y_centers_per_image - gt_bboxes_per_image_t
        b_b = gt_bboxes_per_image_b - y_centers_per_image
        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)               # ([5, 8400, 4]): the four offsets between each anchor-point center and the gt box edges

        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0                               # ([5, 8400])  
        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
        # in fixed center

        center_radius = 2.5

        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
            1, total_num_anchors                                          # (5, 1) -> (5, 8400)
        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) + center_radius * expanded_strides_per_image.unsqueeze(0)

        c_l = x_centers_per_image - gt_bboxes_per_image_l
        c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)                                        # ([5, 8400, 4])
        is_in_centers = center_deltas.min(dim=-1).values > 0.0
        is_in_centers_all = is_in_centers.sum(dim=0) > 0

        # in boxes and in centers
        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all                        # (8400)      : 3473*[True]

        is_in_boxes_and_center = (
            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]     # ([5, 3473])  :325*[True]
        )
        return is_in_boxes_anchor, is_in_boxes_and_center
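
A toy illustration of the two filters above, written from scratch for intuition (not the library code): one gt box on a single stride-8 level.

import torch

stride = 8
gt = torch.tensor([100.0, 100.0, 60.0, 40.0])                 # one gt box: cx, cy, w, h on a 640*640 image
ys, xs = torch.meshgrid(torch.arange(80), torch.arange(80), indexing="ij")   # indexing kwarg needs PyTorch >= 1.10
cx = (xs.flatten() + 0.5) * stride                            # anchor-point centers, as in get_in_boxes_info
cy = (ys.flatten() + 0.5) * stride

# "in box": the anchor-point center falls inside the gt rectangle
in_box = (cx > gt[0] - gt[2] / 2) & (cx < gt[0] + gt[2] / 2) \
       & (cy > gt[1] - gt[3] / 2) & (cy < gt[1] + gt[3] / 2)

# "in center": the center falls inside a 2.5*stride square around the gt center
r = 2.5 * stride
in_center = (cx > gt[0] - r) & (cx < gt[0] + r) & (cy > gt[1] - r) & (cy < gt[1] + r)

print(in_box.sum().item(), in_center.sum().item(), (in_box & in_center).sum().item())   # 35 25 25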

3.1.2 self.dynamic_k_matching

Dynamically selects k positive samples per ground truth based on IoU.
For example, the 5 gts here are assigned 34 positive samples in total, and the function also returns each sample's matched IoU score (pred_ious_this_matching).
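For each gt, k is the truncated sum of its top-10 IoUs with the candidate anchors, clamped to at least 1: e.g. if a gt's top-10 IoUs sum to 3.6, dynamic_k = 3, and the 3 lowest-cost anchors for that gt become its positives (illustrative numbers).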

    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        # Dynamic K
        # ---------------------------------------------------------------
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)                          # ([5, 3473])

        ious_in_boxes_matrix = pair_wise_ious
        n_candidate_k = min(10, ious_in_boxes_matrix.size(1))                                    # 10
        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)                # ( 5, 10 )
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
        dynamic_ks = dynamic_ks.tolist()                                                                                 # [3, 7, 9, 9, 6]
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
            )                                                            # pick the dynamic_k lowest-cost anchors out of the 3473, e.g. pos_idx: [3236, 3235, 3237]
            matching_matrix[gt_idx][pos_idx] = 1                         # each row (one gt) of the all-zero matching_matrix ([5, 3473]) gets [3, 7, 9, 9, 6] ones respectively

        del topk_ious, dynamic_ks, pos_idx

        anchor_matching_gt = matching_matrix.sum(0)                                                   # ( 3473 )
        if (anchor_matching_gt > 1).sum() > 0:                           # an anchor matched to several gts keeps only the lowest-cost gt
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            matching_matrix[:, anchor_matching_gt > 1] *= 0
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
        fg_mask_inboxes = matching_matrix.sum(0) > 0                     # (3473), 34 True
        num_fg = fg_mask_inboxes.sum().item()                            # 34

        fg_mask[fg_mask.clone()] = fg_mask_inboxes

        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)       # ([5, 3473]) --> ([5, 34]).argmax --> (34)
               #   [4, 4, 2, 4, 4, 4, 3, 3, 3, 3, 1, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 3, 2, 4, 3, 3, 2, 2, 1, 3, 1, 3, 1]
        gt_matched_classes = gt_classes[matched_gt_inds]
               #  ( 34 ): [ 14., 14., 14., 14., 14., 14., 14., 14., 14., 14.,  8.,  8.,  8.,  8., 11., 11., 11., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14.,  8., 14.,  8., 14.,  8. ] 
        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
            fg_mask_inboxes
        ]                                             # (34) IoU scores
        return (
            num_fg,                                       # 34
            gt_matched_classes,                           # (34): classes of the 34 positive samples
            pred_ious_this_matching,                      # (34): IoU of each positive with its matched gt
            matched_gt_inds,                              # (34): index of the gt each positive is matched to
        )
        # fg_mask ((8400), now with 34 True entries) is not returned; it was updated in place above

4. Backpropagation

 outputs = self.model(inps, targets)

 loss = outputs["total_loss"]

 self.optimizer.zero_grad()
 self.scaler.scale(loss).backward()
 self.scaler.step(self.optimizer)
 self.scaler.update()
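
For context, a minimal AMP training step of this shape looks roughly like the sketch below (model, optimizer and dataloader are placeholders; the autocast wrapper around the forward pass lives in the trainer but is not shown in the excerpt above):

import torch

scaler = torch.cuda.amp.GradScaler(enabled=True)

for inps, targets in dataloader:                  # placeholder iterable of (images, labels)
    inps, targets = inps.cuda(), targets.cuda()
    with torch.cuda.amp.autocast(enabled=True):
        outputs = model(inps, targets)            # training forward returns the loss dict
        loss = outputs["total_loss"]

    optimizer.zero_grad()
    scaler.scale(loss).backward()                 # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)                        # unscales the gradients, then calls optimizer.step()
    scaler.update()                               # adjust the scale factor for the next iteration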
