RetinaNet: Focal loss在目标检测网络中的应用

介绍

RetinaNet是2018年Facebook AI团队在目标检测领域新的贡献。它的重要作者名单中Ross Girshick与Kaiming He赫然在列。来自Microsoft的Sun Jian团队与现在Facebook的Ross/Kaiming团队在当前视觉目标分类、检测领域有着北乔峰、南慕容一般的独特地位。这两个实验室的文章多是行业里前进方向的提示牌。

RetinaNet只是原来FPN网络与FCN网络的组合应用,因此在目标网络检测框架上它并无特别亮眼创新。文章中最大的创新来自于Focal loss的提出及在单阶段目标检测网络RetinaNet(实质为Resnet + FPN + FCN)的成功应用。Focal loss是一种改进了的交叉熵(cross-entropy, CE)loss,它通过在原有的CE loss上乘了个使易检测目标对模型训练贡献削弱的指数式,从而使得Focal loss成功地解决了在目标检测时,正负样本区域极不平衡而目标检测loss易被大批量负样本所左右的问题。此问题是单阶段目标检测框架(如SSD/Yolo系列)与双阶段目标检测框架(如Faster-RCNN/R-FCN等)accuracy gap的最大原因。在Focal loss提出之前,已有的目标检测网络都是通过像Boot strapping/Hard example mining等方法来解决此问题的。作者通过后续实验成功表明Focal loss可在单阶段目标检测网络中成功使用,并最终能以更快的速率实现与双阶段目标检测网络近似或更优的效果。

类别不平衡问题

常规的单阶段目标检测网络像SSD一般在模型训练时会先大密度地在模型终端的系列feature maps上生成出10,000甚至100,0000个目标候选区域。然后再分别对这些候选区域进行分类与位置回归识别。而在这些生成的数万个候选区域中,绝大多数都是不包含待检测目标的图片背景,这样就造成了机器学习中经典的训练样本正负不平衡的问题。它往往会造成最终算出的training loss为占绝对多数但包含信息量却很少的负样本所支配,少样正样本提供的关键信息却不能在一般所用的training loss中发挥正常作用,从而无法得出一个能对模型训练提供正确指导的loss。

常用的解决此问题的方法就是负样本挖掘。或其它更复杂的用于过滤负样本从而使正负样本数维持一定比率的样本取样方法。而在此篇文章中作者提出了可通过候选区域包含潜在目标概率进而对最终的training loss进行较正的方法。实验表明这种新提出的focal loss在单阶段目标检测任务上表现突出,有效地解决了此领域里面潜在的类别不平衡问题。

Focal loss

  • CE(cross-entropy) loss

以下为典型的交叉熵loss,它广泛用于当下的图像分类、检测CNN网络当中。

RetinaNet: Focal loss在目标检测网络中的应用_第1张图片
Cross-entropy_loss
  • Balanced CE loss

考虑到上节中提到的类别不平衡问题对最终training loss的不利影响,我们自然会想到可通过在loss公式中使用与目标存在概率成反比的系数对其进行较正。如下公式即是此朴素想法的体现。它也是作者最终Focus loss的baseline。

RetinaNet: Focal loss在目标检测网络中的应用_第2张图片
Balanced_CE-loss
  • Focal loss定义

以下是作者提出的focal loss的想法。

RetinaNet: Focal loss在目标检测网络中的应用_第3张图片
Focal_loss定义

下图为focal loss与常规CE loss的对比。从中,我们易看出focal loss所加的指数式系数可对正负样本对loss的贡献自动调节。当某样本类别比较明确些,它对整体loss的贡献就比较少;而若某样本类别不易区分,则对整体loss的贡献就相对偏大。这样得到的loss最终将集中精力去诱导模型去努力分辨那些难分的目标类别,于是就有效提升了整体的目标检测准度。不过在此focus loss计算当中,我们引入了一个新的hyper parameter即γ。一般来说新参数的引入,往往会伴随着模型使用难度的增加。在本文中,作者有试者对其进行调节,线性搜索后得出将γ设为2时,模型检测效果最好。

RetinaNet: Focal loss在目标检测网络中的应用_第4张图片
几种loss的对比

在最终所用的focal loss上,作者还引入了α系数,它能够使得focal loss对不同类别更加平衡。实验表明它会比原始的focal loss效果更好。

RetinaNet: Focal loss在目标检测网络中的应用_第5张图片
最终所用的Focal_loss

模型的初始化参数选择

一般我们初始化CNN网络模型时都会使用无偏的参数对其初始化,比如Conv的kernel 参数我们会以bias 为0,variance为0.01的某分布来对其初始化。但是如果我们的模型要去处理类别极度不平衡的情况,那么就会考虑到这样对训练数据分布无任选先验假设的初始化会使得在训练过程中,我们的参数更偏向于拥有更多数量的负样本的情况去进化。作者观察下来发现它在训练时会出现极度的不稳定。于是作者在初始化模型最后一层参数时考虑了数据样本分布的不平衡性,这样使得初始训练时最终得出的loss不会对过多的负样本数量所惊讶到,从而有效地规避了初始训练时模型的震荡与不稳定。

RetinaNet检测框架

RetinaNet本质上是Resnet + FPN + 两个FCN子网络。
以下为RetinaNet目标框架框架图。有了之前blog里面提到的FPN与FCN的知识后,我们很容易理解此框架的设计含义。

RetinaNet: Focal loss在目标检测网络中的应用_第6张图片
RetinaNet目标检测框架

一般主干网络可选用任一有效的特征提取网络如vgg16或resnet系列,此处作者分别尝试了resnet-50与resnet-101。而FPN则是对resnet-50里面自动形成的多尺度特征进行了强化利用,从而得到了表达力更强、包含多尺度目标区域信息的feature maps集合。最后在FPN所吐出的feature maps集合上,分别使用了两个FCN子网络(它们有着相同的网络结构却各自独立,并不share参数)用来完成目标框类别分类与位置回归任务。

模型的推理与训练

  • 模型推理

一旦我们有了训练好的模型,在正式部署时,只需对其作一次forward,然后对最终生成的目标区域进行过渡。然后只对每个FPN level上目标存在概率最高的前1000个目标框进一步地decoding处理。接下来再将所有FPN level上得到的目标框汇集起来,统一使用极大值抑制的方法进一步过渡(其中极大值抑制时所用的阈值为0.5)。这样,我们就得到了最终的目标与其位置框架。

  • 模型训练

模型训练中主要在后端Loss计算时采用了Focal loss,另外也在模型初始化时考虑到了正负样本极度不平衡的情况进而对模型最后一个conv layer的bias参数作了有偏初始化。

训练时用了SGD,mini batch size为16,在8个GPU上一块训练,每个GPU上local batch size为2。最大iterations数目为90,000;模型初始lr为0.01,接下来随着训练进行分step wisely 降低。真正的training loss则为表达目标类别的focus loss与表达目标框位置回归信息的L1 loss的和。

下图为RetinaNet模型的检测准度与性能。

RetinaNet: Focal loss在目标检测网络中的应用_第7张图片
RetinaNet的检测准度与性能

代码实例

以下函数用于从FPN的各个level的feature maps上提取各种scale的anchor box。

def _create_cell_anchors():
    """
    Generate all types of anchors for all fpn levels/scales/aspect ratios.
    This function is called only once at the beginning of inference.
    """
    k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
    scales_per_octave = cfg.RETINANET.SCALES_PER_OCTAVE
    aspect_ratios = cfg.RETINANET.ASPECT_RATIOS
    anchor_scale = cfg.RETINANET.ANCHOR_SCALE
    A = scales_per_octave * len(aspect_ratios)
    anchors = {}
    for lvl in range(k_min, k_max + 1):
        # create cell anchors array
        stride = 2. ** lvl
        cell_anchors = np.zeros((A, 4))
        a = 0
        for octave in range(scales_per_octave):
            octave_scale = 2 ** (octave / float(scales_per_octave))
            for aspect in aspect_ratios:
                anchor_sizes = (stride * octave_scale * anchor_scale, )
                anchor_aspect_ratios = (aspect, )
                cell_anchors[a, :] = generate_anchors(
                    stride=stride, sizes=anchor_sizes,
                    aspect_ratios=anchor_aspect_ratios)
                a += 1
        anchors[lvl] = cell_anchors
    return anchors

下面函数则描述了如何使用train好的RetinaNet来进行图片目标检测。

def im_detect_bbox(model, im, timers=None):
    """Generate RetinaNet detections on a single image."""
    if timers is None:
        timers = defaultdict(Timer)
    # Although anchors are input independent and could be precomputed,
    # recomputing them per image only brings a small overhead
    anchors = _create_cell_anchors()
    timers['im_detect_bbox'].tic()
    k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
    A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)
    inputs = {}
    inputs['data'], im_scale, inputs['im_info'] = \
        blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
    cls_probs, box_preds = [], []
    for lvl in range(k_min, k_max + 1):
        suffix = 'fpn{}'.format(lvl)
        cls_probs.append(core.ScopedName('retnet_cls_prob_{}'.format(suffix)))
        box_preds.append(core.ScopedName('retnet_bbox_pred_{}'.format(suffix)))
    for k, v in inputs.items():
        workspace.FeedBlob(core.ScopedName(k), v.astype(np.float32, copy=False))

    workspace.RunNet(model.net.Proto().name)
    cls_probs = workspace.FetchBlobs(cls_probs)
    box_preds = workspace.FetchBlobs(box_preds)

    # here the boxes_all are [x0, y0, x1, y1, score]
    boxes_all = defaultdict(list)

    cnt = 0
    for lvl in range(k_min, k_max + 1):
        # create cell anchors array
        stride = 2. ** lvl
        cell_anchors = anchors[lvl]

        # fetch per level probability
        cls_prob = cls_probs[cnt]
        box_pred = box_preds[cnt]
        cls_prob = cls_prob.reshape((
            cls_prob.shape[0], A, int(cls_prob.shape[1] / A),
            cls_prob.shape[2], cls_prob.shape[3]))
        box_pred = box_pred.reshape((
            box_pred.shape[0], A, 4, box_pred.shape[2], box_pred.shape[3]))
        cnt += 1

        if cfg.RETINANET.SOFTMAX:
            cls_prob = cls_prob[:, :, 1::, :, :]

        cls_prob_ravel = cls_prob.ravel()
        # In some cases [especially for very small img sizes], it's possible that
        # candidate_ind is empty if we impose threshold 0.05 at all levels. This
        # will lead to errors since no detections are found for this image. Hence,
        # for lvl 7 which has small spatial resolution, we take the threshold 0.0
        th = cfg.RETINANET.INFERENCE_TH if lvl < k_max else 0.0
        candidate_inds = np.where(cls_prob_ravel > th)[0]
        if (len(candidate_inds) == 0):
            continue

        pre_nms_topn = min(cfg.RETINANET.PRE_NMS_TOP_N, len(candidate_inds))
        inds = np.argpartition(
            cls_prob_ravel[candidate_inds], -pre_nms_topn)[-pre_nms_topn:]
        inds = candidate_inds[inds]

        inds_5d = np.array(np.unravel_index(inds, cls_prob.shape)).transpose()
        classes = inds_5d[:, 2]
        anchor_ids, y, x = inds_5d[:, 1], inds_5d[:, 3], inds_5d[:, 4]
        scores = cls_prob[:, anchor_ids, classes, y, x]

        boxes = np.column_stack((x, y, x, y)).astype(dtype=np.float32)
        boxes *= stride
        boxes += cell_anchors[anchor_ids, :]

        if not cfg.RETINANET.CLASS_SPECIFIC_BBOX:
            box_deltas = box_pred[0, anchor_ids, :, y, x]
        else:
            box_cls_inds = classes * 4
            box_deltas = np.vstack(
                [box_pred[0, ind:ind + 4, yi, xi]
                 for ind, yi, xi in zip(box_cls_inds, y, x)]
            )
        pred_boxes = (
            box_utils.bbox_transform(boxes, box_deltas)
            if cfg.TEST.BBOX_REG else boxes)
        pred_boxes /= im_scale
        pred_boxes = box_utils.clip_tiled_boxes(pred_boxes, im.shape)
        box_scores = np.zeros((pred_boxes.shape[0], 5))
        box_scores[:, 0:4] = pred_boxes
        box_scores[:, 4] = scores

        for cls in range(1, cfg.MODEL.NUM_CLASSES):
            inds = np.where(classes == cls - 1)[0]
            if len(inds) > 0:
                boxes_all[cls].extend(box_scores[inds, :])
    timers['im_detect_bbox'].toc()

    # Combine predictions across all levels and retain the top scoring by class
    timers['misc_bbox'].tic()
    detections = []
    for cls, boxes in boxes_all.items():
        cls_dets = np.vstack(boxes).astype(dtype=np.float32)
        # do class specific nms here
        keep = box_utils.nms(cls_dets, cfg.TEST.NMS)
        cls_dets = cls_dets[keep, :]
        out = np.zeros((len(keep), 6))
        out[:, 0:5] = cls_dets
        out[:, 5].fill(cls)
        detections.append(out)

    # detections (N, 6) format:
    #   detections[:, :4] - boxes
    #   detections[:, 4] - scores
    #   detections[:, 5] - classes
    detections = np.vstack(detections)
    # sort all again
    inds = np.argsort(-detections[:, 4])
    detections = detections[inds[0:cfg.TEST.DETECTIONS_PER_IM], :]

    # Convert the detections to image cls_ format (see core/test_engine.py)
    num_classes = cfg.MODEL.NUM_CLASSES
    cls_boxes = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]
    for c in range(1, num_classes):
        inds = np.where(detections[:, 5] == c)[0]
        cls_boxes[c] = detections[inds, :5]
    timers['misc_bbox'].toc()

    return cls_boxes

以下为RetinaNet中training loss的具体计算。可以看出它包含了两个部分分别为反映位置信息的L1 loss与反映类别信息的focus loss。

def add_fpn_retinanet_losses(model):
    loss_gradients = {}
    gradients, losses = [], []

    k_max = cfg.FPN.RPN_MAX_LEVEL  # coarsest level of pyramid
    k_min = cfg.FPN.RPN_MIN_LEVEL  # finest level of pyramid

    model.AddMetrics(['retnet_fg_num', 'retnet_bg_num'])
    # ==========================================================================
    # bbox regression loss - SelectSmoothL1Loss for multiple anchors at a location
    # ==========================================================================
    for lvl in range(k_min, k_max + 1):
        suffix = 'fpn{}'.format(lvl)
        bbox_loss = model.net.SelectSmoothL1Loss(
            [
                'retnet_bbox_pred_' + suffix,
                'retnet_roi_bbox_targets_' + suffix,
                'retnet_roi_fg_bbox_locs_' + suffix, 'retnet_fg_num'
            ],
            'retnet_loss_bbox_' + suffix,
            beta=cfg.RETINANET.BBOX_REG_BETA,
            scale=model.GetLossScale() * cfg.RETINANET.BBOX_REG_WEIGHT
        )
        gradients.append(bbox_loss)
        losses.append('retnet_loss_bbox_' + suffix)

    # ==========================================================================
    # cls loss - depends on softmax/sigmoid outputs
    # ==========================================================================
    for lvl in range(k_min, k_max + 1):
        suffix = 'fpn{}'.format(lvl)
        cls_lvl_logits = 'retnet_cls_pred_' + suffix
        if not cfg.RETINANET.SOFTMAX:
            cls_focal_loss = model.net.SigmoidFocalLoss(
                [
                    cls_lvl_logits, 'retnet_cls_labels_' + suffix,
                    'retnet_fg_num'
                ],
                ['fl_{}'.format(suffix)],
                gamma=cfg.RETINANET.LOSS_GAMMA,
                alpha=cfg.RETINANET.LOSS_ALPHA,
                scale=model.GetLossScale(),
                num_classes=model.num_classes - 1
            )
            gradients.append(cls_focal_loss)
            losses.append('fl_{}'.format(suffix))
        else:
            cls_focal_loss, gated_prob = model.net.SoftmaxFocalLoss(
                [
                    cls_lvl_logits, 'retnet_cls_labels_' + suffix,
                    'retnet_fg_num'
                ],
                ['fl_{}'.format(suffix), 'retnet_prob_{}'.format(suffix)],
                gamma=cfg.RETINANET.LOSS_GAMMA,
                alpha=cfg.RETINANET.LOSS_ALPHA,
                scale=model.GetLossScale(),
                num_classes=model.num_classes
            )
            gradients.append(cls_focal_loss)
            losses.append('fl_{}'.format(suffix))

    loss_gradients.update(blob_utils.get_loss_gradients(model, gradients))
    model.AddLosses(losses)
    return loss_gradients

参考文献

  • Focal Loss for Dense Object Detection, Tsung-Yi Lin, 2018
  • https://github.com/facebookresearch/Detectron

你可能感兴趣的:(RetinaNet: Focal loss在目标检测网络中的应用)