Implementing FCOS from Scratch (Part 2): Ground Truth Assignment and Loss Computation

Table of Contents

  • Anchor free? Anchor based?
  • FCOS ground truth assignment
  • Loss computation
  • Complete loss code

All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and confirmed to work correctly.

Anchor free? Anchor based?

First, to be clear: FCOS indeed does not use explicit anchors (prior boxes) the way RetinaNet does. FCOS treats every point on every FPN level's feature map as a sample, and whether that sample is positive or negative is decided by whether the point lies inside or outside an annotated box (note that FCOS has no ignored samples). In this sense, FCOS is indeed anchor free. However, both during ground truth assignment and at inference, FCOS still needs the (x, y) coordinates obtained by projecting each feature map point back onto the input image, so FCOS is not completely "free"; more precisely, FCOS is a "point based" detector. You can think of FCOS as a detector in which each feature map point carries exactly one implicit anchor.
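
The batch_positions returned by the network already contain these projected (x, y) coordinates for every point. As a minimal, self-contained sketch (function and variable names here are illustrative, not the repo's), the mapping described in the FCOS paper sends point (x, y) on a feature map with stride s to (s // 2 + x * s, s // 2 + y * s) on the input image:

import torch

def feature_map_points_to_image(feature_h, feature_w, stride):
    # project every point (x, y) on a feature_h x feature_w feature map back to
    # input-image coordinates: (stride // 2 + x * stride, stride // 2 + y * stride)
    shift = stride // 2
    ys = torch.arange(feature_h) * stride + shift
    xs = torch.arange(feature_w) * stride + shift
    grid_y, grid_x = torch.meshgrid(ys, xs)
    # shape [feature_h, feature_w, 2], last dim is (point_ctr_x, point_ctr_y)
    return torch.stack([grid_x, grid_y], dim=-1)

# e.g. the P3 level (stride 8) of a 600x600 input has a 75x75 feature map
points = feature_map_points_to_image(75, 75, 8)
print(points.shape)    # torch.Size([75, 75, 2])
print(points[0, 0])    # tensor([4, 4])
print(points[-1, -1])  # tensor([596, 596])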

The DETR detector released in 2020 (https://arxiv.org/pdf/2005.12872.pdf) casts object detection as a set prediction problem and uses a Transformer to predict the set of boxes. It needs neither NMS nor anchor/point prior coordinates, which makes the detector truly "free". Interested readers can look it up on their own.

FCOS ground truth assignment

For the multiple boxes annotated on an input image, we first examine every point on every FPN level's feature map: if a point lies outside all annotated boxes, it becomes a negative sample. At this stage, some of the remaining points may lie inside several annotated boxes at the same time. Next, for each point and each annotated box we take the maximum of l, t, r, b (the point's distances to the box's left, top, right and bottom edges); the range that this maximum falls into decides which FPN level the box is assigned to, at the corresponding point on that level's feature map. A small worked example follows the ranges below.

# value ranges assigned to P3, P4, P5, P6, P7, from left to right
INF=100000000
mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]
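
A quick hand-worked example (values made up for illustration): take a point at (x, y) = (100, 120) that lies inside an annotated box [60, 80, 300, 260] given as (x_min, y_min, x_max, y_max):

l, t = 100 - 60, 120 - 80    # 40, 40
r, b = 300 - 100, 260 - 120  # 200, 140
max_ltrb = max(l, t, r, b)   # 200
# 200 falls into [128, 256], so this box can only be assigned to this point on the P5 feature map;
# the same point on P3/P4/P6/P7 is not a valid match for this box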

After the step above, the vast majority of points are assigned at most one box. Some points, however, may still lie inside two boxes (this happens when two annotated boxes are roughly the same size). For these points we compare the areas of the overlapping boxes and always assign the point to the annotated box with the smallest area. In the implementation below, this part of the label assignment is done with matrix operations. Although there are usually only a few dozen to two or three hundred positive samples per image, assigning their labels with a Python for loop would make training extremely slow, so this is worth keeping in mind.
For the classification labels, 0 denotes a negative sample and 1 to 80 denote the 80 positive classes; the l, t, r, b and centerness targets are computed exactly as in the FCOS paper, without modification (see the formulas below).
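
For reference, for a point (x, y) assigned to a box (x0, y0, x1, y1), the FCOS paper defines the targets as:

$$l^{*}=x-x_{0},\quad t^{*}=y-y_{0},\quad r^{*}=x_{1}-x,\quad b^{*}=y_{1}-y$$

$$\text{centerness}^{*}=\sqrt{\frac{\min(l^{*},r^{*})}{\max(l^{*},r^{*})}\times\frac{\min(t^{*},b^{*})}{\max(t^{*},b^{*})}}$$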

The ground truth assignment is implemented as follows:

    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """
        Assign a ground truth target for each position on feature map
        """
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)

        cls_preds,reg_preds,center_preds,all_points_position,all_points_mi=[],[],[],[],[]
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])

            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)

        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)

        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]

            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 2)
                candidates[:, :,
                           0:2] = per_image_position[:, :,
                                                     0:2] - candidates[:, :,
                                                                       0:2]
                candidates[:, :,
                           2:4] = candidates[:, :,
                                             2:4] - per_image_position[:, :,
                                                                       2:4]

                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # zero out candidates for points whose center lies outside the gt box
                candidates = candidates * sample_flag

                # zero out candidates whose max(l,t,r,b) falls outside this level's mi range
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag

                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = (final_sample_flag == True).nonzero().squeeze(
                    dim=-1)

                # if no positive sample is assigned
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]

                    del candidates

                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)

                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)

                    if positive_candidates.shape[1] == 1:
                        # if only one candidate for each positive sample
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index values from 1 to 80 represent the 80 positive classes
                        # class_index value 0 represents the negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index,
                                          0:4] = positive_candidates
                        per_image_targets[positive_index,
                                          4:5] = sample_class_gts + 1

                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # if a positive point sample has several candidate objects, choose the smallest-area candidate as the ground truth for this point
                        gts_w_h = sample_box_gts[:, :,
                                                 2:4] - sample_box_gts[:, :,
                                                                       0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(
                            axis=2)

                        # set all negative candidates' areas to 100000000 so the .min() below never selects a negative candidate
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)

                        # get the smallest object candidate index
                        _, min_index = gts_area.min(axis=1)
                        candidate_indexes = (
                            torch.linspace(1, positive_candidates.shape[0],
                                           positive_candidates.shape[0]) -
                            1).long()
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]

                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index,
                                          0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index,
                                          4:5] = final_candidate_cls_gts + 1

                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))

            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)

        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position], axis=2)

        # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets

Loss computation

The classification loss is focal loss, computed exactly as in RetinaNet; the only difference is that the samples are points instead of anchors.
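
For reference, the focal loss applied to each point and class is the same as in the RetinaNet paper:

$$FL(p_{t})=-\alpha_{t}(1-p_{t})^{\gamma}\log(p_{t})$$

where p_t is the predicted probability of the ground-truth class, alpha = 0.25 and gamma = 2 by default, and the summed loss is divided by the number of positive points.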

The classification loss is implemented as follows:

    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]

        # generate 80 binary ground truth classes for each point
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()

        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)

        bce_loss = -(
            loss_ground_truth * torch.log(per_image_cls_preds) +
            (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))

        one_image_focal_loss = focal_weight * bce_loss

        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # following the original paper, divide the focal loss by the number of positive points
        one_image_focal_loss = one_image_focal_loss / positive_points_num

        return one_image_focal_loss

In the FCOS paper, the regression loss is IoU loss. Here I use GIoU loss directly. Since the regression loss is only computed on positive samples, the predicted box and the ground truth box always overlap (both contain the sample point), so the extra enclosing-box penalty of GIoU stays small and GIoU loss behaves almost the same as IoU loss.
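
For reference, with A the predicted box, B the ground truth box and C the smallest box enclosing both:

$$\text{GIoU}=\text{IoU}-\frac{|C|-|A\cup B|}{|C|},\qquad L_{GIoU}=1-\text{GIoU}$$

In the code below, each positive point's GIoU loss is additionally weighted by its centerness target.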

The regression loss is implemented as follows:

    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5]

        pred_bboxes_xy_min = per_image_targets[:,
                                               6:8] - per_image_reg_preds[:,
                                                                          0:2]
        pred_bboxes_xy_max = per_image_targets[:,
                                               6:8] + per_image_reg_preds[:,
                                                                          2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:,
                                                                         0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:,
                                                                         2:4]

        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]

        # compute w,h of pred boxes and gt boxes (boxes are in [x_min,y_min,x_max,y_max] format)
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1

        # compute pred boxes area and gt boxes area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]

        # compute union_area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between pred boxes and gt boxes
        ious = overlap_area / union_area

        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)

        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        gious_loss = torch.clamp(gious_loss, min=-1.0, max=1.0)
        # use center_ness_targets as the weight of gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        gious_loss = 2. * gious_loss

        return gious_loss

The final multiplication by 2 balances the magnitude of the regression loss against the other losses.

The centerness branch is optimized with BCE loss. Because the optimization target of the centerness loss is not very stable, during training you will often see the centerness loss drop a little at the beginning and then stop decreasing for a long time; this is normal and nothing to worry about.
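
For reference, over the positive points the centerness loss is a standard binary cross-entropy:

$$L_{ctr}=-\frac{1}{N_{pos}}\sum_{i\in pos}\left[c_{i}^{*}\log(c_{i})+(1-c_{i}^{*})\log(1-c_{i})\right]$$

where c_i is the predicted centerness (after sigmoid) and c_i* is the centerness target of positive point i.
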
The centerness loss is implemented as follows:

    def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                           per_image_targets):
        """
        compute one image center_ness loss(center ness loss)
        per_image_center_preds:[points_num,1]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5:6]

        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))
        center_ness_loss = center_ness_loss.sum() / positive_points_num

        return center_ness_loss

Complete loss code

import torch
import torch.nn as nn
import torch.nn.functional as F

INF = 100000000


class FCOSLoss(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 strides=[8, 16, 32, 64, 128],
                 mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]],
                 alpha=0.25,
                 gamma=2.,
                 epsilon=1e-4):
        super(FCOSLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.image_w = image_w
        self.image_h = image_h
        self.strides = strides
        self.mi = mi

    def forward(self, cls_heads, reg_heads, center_heads, batch_positions,
                annotations):
        """
        compute cls loss, reg loss and center-ness loss in one batch
        """
        cls_preds, reg_preds, center_preds, batch_targets = self.get_batch_position_annotations(
            cls_heads, reg_heads, center_heads, batch_positions, annotations)

        cls_preds = torch.sigmoid(cls_preds)
        reg_preds = torch.exp(reg_preds)
        center_preds = torch.sigmoid(center_preds)
        batch_targets[:, :, 5:6] = torch.sigmoid(batch_targets[:, :, 5:6])

        device = annotations.device
        cls_loss, reg_loss, center_ness_loss = [], [], []
        valid_image_num = 0
        for per_image_cls_preds, per_image_reg_preds, per_image_center_preds, per_image_targets in zip(
                cls_preds, reg_preds, center_preds, batch_targets):
            positive_points_num = (
                per_image_targets[per_image_targets[:, 4] > 0]).shape[0]
            if positive_points_num == 0:
                cls_loss.append(torch.tensor(0.).to(device))
                reg_loss.append(torch.tensor(0.).to(device))
                center_ness_loss.append(torch.tensor(0.).to(device))
            else:
                valid_image_num += 1
                one_image_cls_loss = self.compute_one_image_focal_loss(
                    per_image_cls_preds, per_image_targets)
                one_image_reg_loss = self.compute_one_image_giou_loss(
                    per_image_reg_preds, per_image_targets)
                one_image_center_ness_loss = self.compute_one_image_center_ness_loss(
                    per_image_center_preds, per_image_targets)

                cls_loss.append(one_image_cls_loss)
                reg_loss.append(one_image_reg_loss)
                center_ness_loss.append(one_image_center_ness_loss)

        cls_loss = sum(cls_loss) / valid_image_num
        reg_loss = sum(reg_loss) / valid_image_num
        center_ness_loss = sum(center_ness_loss) / valid_image_num

        return cls_loss, reg_loss, center_ness_loss

    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]

        # generate 80 binary ground truth classes for each point
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()

        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)

        bce_loss = -(
            loss_ground_truth * torch.log(per_image_cls_preds) +
            (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))

        one_image_focal_loss = focal_weight * bce_loss

        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # following the original paper, divide the focal loss by the number of positive points
        one_image_focal_loss = one_image_focal_loss / positive_points_num

        return one_image_focal_loss

    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5]

        pred_bboxes_xy_min = per_image_targets[:,
                                               6:8] - per_image_reg_preds[:,
                                                                          0:2]
        pred_bboxes_xy_max = per_image_targets[:,
                                               6:8] + per_image_reg_preds[:,
                                                                          2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:,
                                                                         0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:,
                                                                         2:4]

        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]

        # compute w,h of pred boxes and gt boxes (boxes are in [x_min,y_min,x_max,y_max] format)
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1

        # compute pred boxes area and gt boxes area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]

        # compute union_area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between pred boxes and gt boxes
        ious = overlap_area / union_area

        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)

        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        gious_loss = torch.clamp(gious_loss, min=-1.0, max=1.0)
        # use center_ness_targets as the weight of gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        gious_loss = 2. * gious_loss

        return gious_loss

    def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                           per_image_targets):
        """
        compute one image center_ness loss(center ness loss)
        per_image_center_preds:[points_num,1]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5:6]

        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))
        center_ness_loss = center_ness_loss.sum() / positive_points_num

        return center_ness_loss

    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """
        Assign a ground truth target for each position on feature map
        """
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)

        cls_preds,reg_preds,center_preds,all_points_position,all_points_mi=[],[],[],[],[]
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])

            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)

        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)

        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]

            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 2)
                candidates[:, :,
                           0:2] = per_image_position[:, :,
                                                     0:2] - candidates[:, :,
                                                                       0:2]
                candidates[:, :,
                           2:4] = candidates[:, :,
                                             2:4] - per_image_position[:, :,
                                                                       2:4]

                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # zero out candidates for points whose center lies outside the gt box
                candidates = candidates * sample_flag

                # zero out candidates whose max(l,t,r,b) falls outside this level's mi range
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag

                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = (final_sample_flag == True).nonzero().squeeze(
                    dim=-1)

                # if no positive sample is assigned
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]

                    del candidates

                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)

                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)

                    if positive_candidates.shape[1] == 1:
                        # if only one candidate for each positive sample
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index values from 1 to 80 represent the 80 positive classes
                        # class_index value 0 represents the negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index,
                                          0:4] = positive_candidates
                        per_image_targets[positive_index,
                                          4:5] = sample_class_gts + 1

                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # if a positive point sample has several candidate objects, choose the smallest-area candidate as the ground truth for this point
                        gts_w_h = sample_box_gts[:, :,
                                                 2:4] - sample_box_gts[:, :,
                                                                       0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(
                            axis=2)

                        # set all negative candidates' areas to 100000000 so the .min() below never selects a negative candidate
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)

                        # get the smallest object candidate index
                        _, min_index = gts_area.min(axis=1)
                        candidate_indexes = (
                            torch.linspace(1, positive_candidates.shape[0],
                                           positive_candidates.shape[0]) -
                            1).long()
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]

                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index,
                                          0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index,
                                          4:5] = final_candidate_cls_gts + 1

                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))

            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)

        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position], axis=2)

        # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets


if __name__ == '__main__':
    from fcos import FCOS
    net = FCOS(resnet_type="resnet50")
    image_h, image_w = 600, 600
    cls_heads, reg_heads, center_heads, batch_positions = net(
        torch.autograd.Variable(torch.randn(3, 3, image_h, image_w)))
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])
    loss = FCOSLoss(image_w, image_h)
    cls_loss, reg_loss, center_loss = loss(cls_heads, reg_heads, center_heads,
                                           batch_positions, annotations)
    print("2222", cls_loss, reg_loss, center_loss)
