NanoDet关键代码详解(首发)

更新中。。。

task.py

    def training_step(self, batch, batch_idx):
        """Run one training step and log the loss components.

        :param batch: batch dict forwarded to the model's forward_train.
        :param batch_idx: index of the batch within the current epoch.
        :return: the total training loss for the optimizer.
        """
        preds, loss, loss_states = self.model.forward_train(batch)

        if self.log_style == 'Lightning':
            # Native Lightning logging: lr per step, each loss per step+epoch.
            current_lr = self.optimizers().param_groups[0]['lr']
            self.log('lr', current_lr, on_step=True, on_epoch=False, prog_bar=True)
            for name, value in loss_states.items():
                self.log('Train/' + name, value, on_step=True, on_epoch=True,
                         prog_bar=True, sync_dist=True)
        elif self.log_style == 'NanoDet' and self.global_step % self.cfg.log.interval == 0:
            # NanoDet-style logging: scalar summaries plus one formatted line,
            # emitted every cfg.log.interval steps.
            current_lr = self.optimizers().param_groups[0]['lr']
            self.scalar_summary('Train_loss/lr', 'Train', current_lr, self.global_step)
            msg_parts = ['Train|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(
                self.current_epoch + 1, self.cfg.schedule.total_epochs,
                self.global_step, batch_idx, current_lr)]
            for name in loss_states:
                mean_val = loss_states[name].mean().item()
                msg_parts.append('{}:{:.4f}| '.format(name, mean_val))
                self.scalar_summary('Train_loss/' + name, 'Train', mean_val,
                                    self.global_step)
            self.info(''.join(msg_parts))

        return loss

one_stage_detector.py

    def forward_train(self, gt_meta):
        """Forward a batch and compute the training losses.

        :param gt_meta: batch dict; 'img' is the network input, the remaining
            entries carry the ground truth consumed by the head loss.
        :return: (raw head predictions, total loss, dict of loss components)
        """
        # Run the network on the batch images to get the prediction maps.
        predictions = self(gt_meta['img'])
        # Let the detection head match predictions against the ground truth.
        total_loss, loss_states = self.head.loss(predictions, gt_meta)
        return predictions, total_loss, loss_states

gfl_head.py

    def loss(self, preds, gt_meta):
        """Compute the detection losses (loss_qfl / loss_bbox / loss_dfl) for a batch.

        :param preds: tuple (cls_scores, bbox_preds) of per-level prediction maps.
        :param gt_meta: batch dict; reads 'gt_bboxes' and 'gt_labels'.
        :return: (total loss, dict of the three loss components), or None when
            target assignment fails. NOTE(review): callers unpack two values,
            so the None path would raise upstream -- confirm it is reachable.
        """
        cls_scores, bbox_preds = preds
        batch_size = cls_scores[0].shape[0]
        device = cls_scores[0].device
        gt_bboxes = gt_meta['gt_bboxes']
        gt_labels = gt_meta['gt_labels']
        gt_bboxes_ignore = None

        # Spatial size (H, W) of every feature level.
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]

        cls_reg_targets = self.target_assign(batch_size, featmap_sizes, gt_bboxes,
                                             gt_bboxes_ignore, gt_labels, device=device)
        if cls_reg_targets is None:
            return None

        (grid_cells_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

        # Average the positive-sample count across workers; clamp to >= 1 so
        # the qfl normaliser never divides by zero.
        num_total_samples = reduce_mean(
            torch.tensor(num_total_pos).to(device)).item()
        num_total_samples = max(num_total_samples, 1.0)

        # Compute each loss per feature level.
        losses_qfl, losses_bbox, losses_dfl, \
        avg_factor = multi_apply(
            self.loss_single,
            grid_cells_list,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            self.strides,
            num_total_samples=num_total_samples)

        # Normalisation factor for the bbox/dfl losses: sum of the per-level
        # weight sums, averaged across workers.
        avg_factor = sum(avg_factor)
        avg_factor = reduce_mean(avg_factor).item()
        if avg_factor <= 0:
            # No positive samples anywhere: emit zero losses that still
            # participate in autograd.
            loss_qfl = torch.tensor(0, dtype=torch.float32, requires_grad=True).to(device)
            loss_bbox = torch.tensor(0, dtype=torch.float32, requires_grad=True).to(device)
            loss_dfl = torch.tensor(0, dtype=torch.float32, requires_grad=True).to(device)
        else:
            # Normalise, then sum the per-level losses over all feature levels.
            losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox))
            losses_dfl = list(map(lambda x: x / avg_factor, losses_dfl))

            loss_qfl = sum(losses_qfl)
            loss_bbox = sum(losses_bbox)
            loss_dfl = sum(losses_dfl)

        loss = loss_qfl + loss_bbox + loss_dfl
        loss_states = dict(
            loss_qfl=loss_qfl,
            loss_bbox=loss_bbox,
            loss_dfl=loss_dfl)

        return loss, loss_states

 

    def target_assign(self,
                      batch_size,
                      featmap_sizes,
                      gt_bboxes_list,
                      gt_bboxes_ignore_list,
                      gt_labels_list,
                      device):
        """
        Assign target for a batch of images.
        :param batch_size: num of images in one batch
        :param featmap_sizes: A list of all grid cell boxes in all image
        :param gt_bboxes_list: A list of ground truth boxes in all image
        :param gt_bboxes_ignore_list: A list of all ignored boxes in all image
        :param gt_labels_list: A list of all ground truth label in all image
        :param device: pytorch device
        :return: Assign results of all images, or None when any image has no
            valid cells.
        """
        # Grid cells of every feature level (same geometry for all images).
        per_level_cells = [
            self.get_grid_cells(featmap_sizes[lvl],
                                self.grid_cell_scale,
                                stride,
                                dtype=torch.float32,
                                device=device)
            for lvl, stride in enumerate(self.strides)
        ]
        # Cell count per level, replicated once per image.
        num_level_cells = [cells.size(0) for cells in per_level_cells]
        num_level_cells_list = [num_level_cells] * batch_size
        # One flat tensor per image with all levels concatenated.
        mlvl_grid_cells_list = [torch.cat(per_level_cells) for _ in range(batch_size)]

        # Default the optional per-image inputs to lists of None.
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None] * batch_size
        if gt_labels_list is None:
            gt_labels_list = [None] * batch_size

        # Assign targets image by image; each returned list has batch_size
        # entries whose first dimension is the total grid-cell count.
        (all_grid_cells, all_labels, all_label_weights, all_bbox_targets,
         all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
            self.target_assign_single_img,
            mlvl_grid_cells_list,
            num_level_cells_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list)
        # Abort when any image produced no valid cells.
        if any(labels is None for labels in all_labels):
            return None

        # Count sampled cells over the batch (each image contributes >= 1).
        num_total_pos = sum(max(inds.numel(), 1) for inds in pos_inds_list)
        num_total_neg = sum(max(inds.numel(), 1) for inds in neg_inds_list)

        # Merge the per-image tensors into batches, re-split per level.
        mlvl_grid_cells = images_to_levels(all_grid_cells, num_level_cells)
        mlvl_labels = images_to_levels(all_labels, num_level_cells)
        mlvl_label_weights = images_to_levels(all_label_weights, num_level_cells)
        mlvl_bbox_targets = images_to_levels(all_bbox_targets, num_level_cells)
        mlvl_bbox_weights = images_to_levels(all_bbox_weights, num_level_cells)
        return (mlvl_grid_cells, mlvl_labels, mlvl_label_weights,
                mlvl_bbox_targets, mlvl_bbox_weights, num_total_pos,
                num_total_neg)

 

    def target_assign_single_img(self,
                                 grid_cells,
                                 num_level_cells,
                                 gt_bboxes,
                                 gt_bboxes_ignore,
                                 gt_labels):
        """
        Using ATSS Assigner to assign target on one image.
        :param grid_cells: Grid cell boxes of all pixels on feature map
        :param num_level_cells: numbers of grid cells on each level's feature map
        :param gt_bboxes: Ground truth boxes (numpy array; converted to tensor here)
        :param gt_bboxes_ignore: Ground truths which are ignored
        :param gt_labels: Ground truth labels (numpy array; converted to tensor here)
        :return: Assign results of a single image
        """
        device = grid_cells.device
        gt_bboxes = torch.from_numpy(gt_bboxes).to(device)
        gt_labels = torch.from_numpy(gt_labels).to(device)
        
        # Use the ATSS assigner to pick the grid cells responsible for each gt.
        assign_result = self.assigner.assign(grid_cells, num_level_cells,
                                             gt_bboxes, gt_bboxes_ignore,
                                             gt_labels)
        
        # Extract positive/negative cell indices, plus each positive cell's
        # matched gt box and gt index.
        pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = \
            self.sample(assign_result, gt_bboxes)
        
        # Build empty targets, then fill the positive positions with the
        # matched gt boxes and labels.
        num_cells = grid_cells.shape[0]
        bbox_targets = torch.zeros_like(grid_cells)
        bbox_weights = torch.zeros_like(grid_cells)
        # Background cells keep label == num_classes (foreground ids are
        # 0 .. num_classes - 1).
        labels = grid_cells.new_full((num_cells,),
                                     self.num_classes,
                                     dtype=torch.long)
        label_weights = grid_cells.new_zeros(num_cells, dtype=torch.float)

        if len(pos_inds) > 0:
            pos_bbox_targets = pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[pos_assigned_gt_inds]

            label_weights[pos_inds] = 1.0
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # grid_cells: cells covering the original image, one per feature-map
        #   pixel; NOTE(review): per the original author each cell's side is
        #   stride * scale (e.g. stride=8, scale=5 on the first level) --
        #   confirm against get_grid_cells.
        # labels: gt_labels written at pos_inds, num_classes elsewhere.
        # label_weights: 1.0 at both positive and negative cells.
        # bbox_targets: gt boxes written at pos_inds, zeros elsewhere.
        # bbox_weights: 1.0 rows at pos_inds, zeros elsewhere.
        # pos_inds / neg_inds: indices of positive / negative cells.
        return (grid_cells, labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds)

 

    # grid_cells, [100,4]: the original image partitioned into grid cells
    # cls_score, [1,80,10,10]: predicted classification feature map
    # bbox_pred, [1,32,10,10]: predicted bbox regression feature map
    # labels, [100]: gt label at each pos_index position; other entries unused
    # label_weights, [100]
    # bbox_targets, [100,4]: gt bbox at each pos_index position; other entries unused
    # stride: e.g. 32 for this level
    # num_total_samples: e.g. 6 -- number of grid cells responsible for predicting a gt

    def loss_single(self, grid_cells, cls_score, bbox_pred, labels,
                    label_weights, bbox_targets, stride, num_total_samples):
        """Compute qfl / bbox / dfl losses for a single feature level.

        :param grid_cells: [num_cells, 4] grid-cell boxes of this level
        :param cls_score: [N, num_classes, H, W] classification map
        :param bbox_pred: [N, 4*(reg_max+1), H, W] regression distribution map
        :param labels: [num_cells] gt label per cell (num_classes == background)
        :param label_weights: [num_cells] per-cell classification weights
        :param bbox_targets: [num_cells, 4] gt box at positive cells
        :param stride: downsample stride of this level
        :param num_total_samples: normaliser for the qfl loss
        :return: (loss_qfl, loss_bbox, loss_dfl, sum of per-cell weights);
            the last value is accumulated by the caller into avg_factor.
        """
        grid_cells = grid_cells.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4 * (self.reg_max + 1))
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = torch.nonzero((labels >= 0)
                                 & (labels < bg_class_ind), as_tuple=False).squeeze(1)

        # Quality target for the qfl loss; stays 0 at non-positive cells.
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            # Gather gt boxes, predicted distributions and grid cells at the
            # positive positions.
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]  # (n, 4 * (reg_max + 1))
            pos_grid_cells = grid_cells[pos_inds]
            # Cell centers, normalised to this level's coordinate scale.
            pos_grid_cell_centers = self.grid_cells_to_center(pos_grid_cells) / stride

            # Per-cell weight: highest sigmoid class probability, detached so
            # the weighting itself receives no gradient.
            weight_targets = cls_score.detach().sigmoid()
            weight_targets = weight_targets.max(dim=1)[0][pos_inds]
            # Project the predicted distributions to left/top/right/bottom
            # distances, then decode boxes around the cell centers.
            pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred)
            pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers,
                                                 pos_bbox_pred_corners)

            # Bring the gt boxes to the same coordinate scale.
            pos_decode_bbox_targets = pos_bbox_targets / stride

            # IoU between decoded prediction and gt becomes the quality score
            # that qfl regresses the class confidence towards.
            score[pos_inds] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_decode_bbox_targets,
                is_aligned=True)

            # Flatten the four side distributions: (n*4, reg_max + 1).
            pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
            # Continuous gt distances for the dfl loss, flattened to n*4.
            target_corners = bbox2distance(pos_grid_cell_centers,
                                           pos_decode_bbox_targets,
                                           self.reg_max).reshape(-1)

            # Regression loss on the decoded boxes, weighted per cell; the
            # caller divides by the reduced avg_factor, hence avg_factor=1.0.
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=weight_targets,
                avg_factor=1.0)

            # Distribution focal loss on the raw side distributions; each box
            # contributes four corner terms, hence avg_factor=4.0.
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
                avg_factor=4.0)
        else:
            # No positives on this level: zero losses that keep the graph
            # connected to bbox_pred.
            loss_bbox = bbox_pred.sum() * 0
            loss_dfl = bbox_pred.sum() * 0
            # BUGFIX: torch.tensor(0) is int64, so this branch returned an
            # integer weight sum while the positive branch returns float --
            # the caller sums these into avg_factor and all-reduces it, where
            # the dtype must be consistent across levels and ranks.
            # new_tensor(0) matches cls_score's float dtype and device.
            weight_targets = cls_score.new_tensor(0)

        # Quality focal loss over all cells (cls_score vs (labels, IoU score)),
        # normalised by the batch-wide positive count.
        loss_qfl = self.loss_qfl(
            cls_score, (labels, score),
            weight=label_weights,
            avg_factor=num_total_samples)

        return loss_qfl, loss_bbox, loss_dfl, weight_targets.sum()

 


 

 

你可能感兴趣的:(ADAS)