[论文解读] Rank & Sort Loss for Object Detection and Instance Segmentation

文章内容

  • 问题提出
  • 相关研究现状
    • 1. 辅助头和连续标签
    • 2. 基于排序的损失
    • 3. 样本不平衡问题
  • 本文工作
    • AP Loss的不足之处
    • RS Loss 定义
    • 训练细节
  • 代码解读

论文链接:Rank & Sort Loss for Object Detection and Instance Segmentation
mmdet实现代码:Rank & Sort Loss for Object Detection and Instance Segmentation

问题提出

当下常用的损失函数形式如下,将任务 t t t在步骤 k k k的损失加权求和:
L = ∑ k ∈ K ∑ t ∈ T λ t k L t k L=\sum_{k\in K}\sum_{t\in T}\lambda_t^k L_t^k L=kKtTλtkLtk
缺点是超参数过多,很容易引起特定任务之间的不平衡,如正负样本不平衡,级联网络内部的不平衡等,最终得到次优的解决方案。

AP Loss和aLRP Loss与传统的基于分类得分的损失函数相比,具有训练过程和网络评估指标一致性(直接优化网络评估指标AP/aLRP等)、待调超参数较少、对类别不均衡不敏感等优势,但需要更长的训练时间和更多增强操作,且没有建模正样本之间的关联

有研究证明,采用辅助头对正样本定位质量进行排序,或监督分类器直接回归样本的IoU(预测定位精度)能够提高网络性能。

因此本文提出RS Loss,不仅将正样本排序在负样本之前,还基于连续的IoU值在正样本内部进行排序,这将有以下好处:

  1. 正样本内部的排序使得网络不需要辅助头实现对正样本定位质量的排序
  2. 排序性质使得网络能够在没有采样策略的前提下处理极端不平衡的数据集
  3. 借助分类得分和IoU值共同排序调优,与NMS和网络评价指标如AP具有一致性
  4. 除了学习率没有需要调优的参数

相关研究现状

1. 辅助头和连续标签

很多实验证明,用辅助器预测检测结果的定位质量、中心度、IoU、mask-IoU或置信度,并将这些预测与NMS的分类得分相结合,可以提高检测性能。还有研究发现,使用连续的IoU值比使用过辅助器监督分类器效果更好,由此产生使用连续标签训练分类器的Quality Focal Loss并表现出对类别不均衡数据集的鲁棒性。

2. 基于排序的损失

基于排序的损失不可微、难以优化。black-box solvers采用插值AP解决该问题但收效甚微;DR Loss通过对Hinge Loss引入margin实现正负排序;AP Loss和aLRP Loss对性能评估指标进行优化,通过感知学习的误差驱动算法实现不可微部分的优化,但他们需要更长的训练时间和更多的增强手段。RS Loss与之区别在于将连续的定位质量得分作为标签。

3. 样本不平衡问题

常见的解决方法是引入超参数并通过网格化搜索的方式进行调参。有实验采用自平衡策略来平衡分类和定位分支,使两者在aLRP Loss的限定范围内竞争;还有研究使用分类和定位损失的比率来平衡这些任务。
本文中,不同任务的损失值 L t k L_t^k Ltk有自己的限定范围,因此不同任务之间没有竞争关系,

本文工作

  1. 重新定义了错误驱动更新与反向传播,从而解决排序不可微的问题,能够计算排序损失
  2. 定义正负样本之间和正样本内的排序方法,解决类别不均衡问题

AP Loss的不足之处

AP Loss虽然基于排名重新定义了以AP为优化目标的损失函数,并借助感知器误差驱动优化算法实现反向传播,但有以下两处不足:

  1. 产生的损失值 L L L没有考虑目标 L i j ∗ L^∗_{ij} Lij,因此当 L i j ∗ ≠ 0 L^∗_{ij}\neq 0 Lij=0时不可解释
  2. 只计算 i ∈ P , j ∈ N i\in P, j\in N iP,jN时的损失,忽略了 i , j i,j i,j都是正样本时的类内误差,而对于使用连续标签的算法来说类内误差不可忽视,标签越大得分就应该越高。

RS Loss 定义

首先定义当前损失值 l R S ( i ) l_{RS}(i) lRS(i)为正样本 i i iranking errorsorting error 之和,其中ranking error代表正负样本间的排序损失(参考aLRP Loss形式),sorting error对正样本中得分 s j > s i s_j\gt s_i sj>si的样本进行惩罚,其在[0,1]内连续的标签值(如IoU)越大,惩罚项越小:
l R S ( i ) : = N F P ( i ) r a n k ( i ) + ∑ j ∈ P H ( x i j ) ( 1 − y j ) r a n k + ( i ) : = 得 分 比 正 样 本 i 高 的 负 样 本 个 数 正 确 排 名 + 得 分 更 高 的 正 样 本 的 l a b e l s 惩 罚 \begin{aligned} l_{RS}(i)&:=\textcolor{blue}{\frac{N_{FP}(i)}{rank(i)}}+\textcolor{orange}{\frac{\sum_{j\in P}H(x_{ij})(1-y_j)}{rank^+(i)}}\\ \\ &:=\textcolor{blue}{\frac{得分比正样本i高的负样本个数}{正确排名}}+\textcolor{orange}{得分更高的正样本的labels惩罚} \end{aligned} lRS(i):=rank(i)NFP(i)+rank+(i)jPH(xij)(1yj):=i+labels
定义目标损失值 l R S ∗ ( i ) l^*_{RS}(i) lRS(i)如下,当正样本 i i i排序在所有负样本之前时, l R ∗ ( i ) = 0 l^*_{R}(i)=0 lR(i)=0 l S ∗ ( i ) l^*_{S}(i) lS(i)对所有标签值 y j y_j yj大于样本 i i i得分的正样本的 1 − y j 1-y_j 1yj求均值,
l R S ∗ ( i ) = l R ∗ ( i ) + ∑ j ∈ P H ( x i j ) [ y j ≥ y i ] ( 1 − y j ) ∑ j ∈ P H ( x i j ) [ y j ≥ y i ] = 0 + 正 样 本 中 l a b e l 大 于 i 的 样 本 ( 1 − y i ) 之 和 正 样 本 中 l a b e l 大 于 i 的 样 本 个 数 \begin{aligned} l^*_{RS}(i)&=\textcolor{blue}{l^*_{R}(i)}+\textcolor{orange}{\frac{\sum_{j\in P}H(x_{ij})[y_j\ge y_i](1-y_j)}{\sum_{j\in P}H(x_{ij})[y_j\ge y_i]}}\\ \\ &=\textcolor{blue}{0}+\textcolor{orange}{\frac{正样本中label大于i的样本(1-y_i)之和}{正样本中label大于i的样本个数}} \end{aligned} lRS(i)=lR(i)+jPH(xij)[yjyi]jPH(xij)[yjyi](1yj)=0+labelilabeli(1yi)
RS Loss定义为正样本的当前 l R S ( i ) l_{RS}(i) lRS(i)和目标值 l R S ∗ ( i ) l^*_{RS}(i) lRS(i)差异的均值: 1 ∣ P ∣ ∑ i ∈ P ( l R S ( i ) − l R S ∗ ( i ) ) \frac{1}{|P|}\sum_{i\in P}\left( l_{RS}(i)-l^*_{RS}(i) \right) P1iP(lRS(i)lRS(i))
参照AP Loss的三步定义,在此定义兼顾了正负样本和正样本内部误差的 primary term L i j L_{ij} Lij如下:
L i j = { ( l R ( i ) − l R ∗ ( i ) ) p R ( j ∣ i ) f o r   i ∈ P , j ∈ N ( l S ( i ) − l S ∗ ( i ) ) p S ( j ∣ i ) f o r   i ∈ P , j ∈ P 0 o t h e r w i s e L_{ij}= \left\{ \begin{array} {rcl} (l_R(i)-l^*_R(i))p_R(j|i) & {for\space i\in P, j\in N}\\ (l_S(i)-l^*_S(i))p_S(j|i) & {for\space i\in P, j\in P}\\ 0 & {otherwise} \end{array} \right. Lij=(lR(i)lR(i))pR(ji)(lS(i)lS(i))pS(ji)0for iP,jNfor iP,jPotherwise
其中 p R ( j ∣ i ) , p S ( j ∣ i ) p_R(j|i),p_S(j|i) pR(ji),pS(ji)负责将样本 i i i上的误差分别分布在导致误差的样本 j j j上(如只有得分 s j > s i s_j > s_i sj>si的负样本 j j j会引起 ranking error;只有得分 s j > s i s_j > s_i sj>si且标签值 y j < y i y_j < y_i yj<yi的正样本 j j j会引起 sorting error),即:
p R ( j ∣ i ) = H ( x i j ) ∑ k ∈ N H ( x i k ) p S ( j ∣ i ) = H ( x i j ) [ y j < y i ] ∑ k ∈ P H ( x i k ) [ y k < y i ] \begin{aligned} &p_R(j|i)=\frac{H(x_{ij})}{\sum_{k\in N}H(x_{ik})} \\ &p_S(j|i)=\frac{H(x_{ij})[y_jpR(ji)=kNH(xik)H(xij)pS(ji)=kPH(xik)[yk<yi]H(xij)[yj<yi]
因此,Identity Update过程可以分ranking errorsorting error两部分进行:正样本同时受到两部分的影响,而负样本的定位精度对RS Loss没有影响,因此只被ranking error更新。

  • 对于正样本,其反向传播梯度 ∂ L R S ∂ s i \frac{\partial L_{RS}}{\partial s_i} siLRS不仅包括样本 i i i本身的ranking errorsorting error,这部分称为promotion update signal;此外还受到得分更高但连续标签值更小的其他正样本(missorted samples)的影响,这部分称为demotion update signal,与promotion符号相反,根据misranked样本 j j j的信号推动该样本 i i i的信息更新。
    1 ∣ P ∣ ( l R S ∗ ( i ) − l R S ( i ) ⏟ p r o m o t i o n   u p d a t e   s i g n a l + ∑ j ∈ P ( l S ( j ) − l S ∗ ( j ) ) p S ( i ∣ j ) ⏟ d e m o t i o n   u p d a t e   s i g n a l ) \frac{1}{|P|}\left( \underbrace{l^*_{RS}(i)-l_{RS}(i)}_{promotion\space update\space signal}+\underbrace{\sum_{j\in P}\left(l_{S}(j)-l^*_{S}(j)\right)p_S(i|j)}_{demotion\space update\space signal}\right) P1promotion update signal lRS(i)lRS(i)+demotion update signal jP(lS(j)lS(j))pS(ij)
  • 对于负样本,其损失值只受到排名损失ranking error的反向传播影响,定义如下:
    ∂ L R S ∂ s i = 1 ∣ P ∣ ∑ j ∈ P l R ( j ) p R ( i ∣ j ) \frac{\partial L_{RS}}{\partial s_i}=\frac{1}{|P|}\sum_{j\in P} l_R(j)p_R(i|j) siLRS=P1jPlR(j)pR(ij)

训练细节

ATSS的损失定义为 L A T S S = L c l s + λ b o x L b o x + λ c t r L c t r L_{ATSS}=L_{cls}+\lambda_{box}L_{box}+\lambda_{ctr}L_{ctr} LATSS=Lcls+λboxLbox+λctrLctr,其中三个分量分别代表分类损失、定位损失(GIOU)和中心点定位损失(交叉熵损失)。
在此删除网络中的辅助头,并用RS Loss替代分类损失,其中连续标签值为目标框与真值框的IoU值,得到 L R S − A T S S = L R S + λ b o x L b o x L_{RS-ATSS}=L_{RS}+\lambda_{box}L_{box} LRSATSS=LRS+λboxLbox,超参数 λ b o x \lambda_{box} λbox通常通过网格搜索设置为常数,在此采用两种启发式无调优算法来确定每个iteration的 λ b o x \lambda_{box} λbox值:

  • 基于损失值: λ b o x = L R S / L b o x \lambda_{box}=L_{RS}/L_{box} λbox=LRS/Lbox
  • 基于梯度: λ b o x = ∣ ∂ L R S ∂ s ∣ / ∣ ∂ L b o x ∂ b ∣ \lambda_{box}=|\frac{\partial L_{RS}}{\partial s}|/|\frac{\partial L_{box}}{\partial b}| λbox=sLRS/bLbox,其中 b , s b,s b,s分别是目标框回归和分类头的输出。

本文实验发现基于损失值的 λ b o x \lambda_{box} λbox设置方法与调参效果相近;并且本文用每个预测框的分类得分对它们的定位GIoU损失进行加权。这两个tricks(基于数值的任务平衡、基于得分的实例加权)都与超参数无关,可以应用于所有网络。

RS Loss 使用总结

  • 使用RS Loss时,通常会删除启发式/随机采样策略,并移除网络的 IoU aux. head 或 IoU Net,直接用预测结果的 IoU 来监督分类 head;
  • 会在预测框回归中使用基于分数的加权策略,并倾向于使用Dice Loss而非交叉熵损失来训练用于实例分割的mask prediction head(Dice Loss有界,且和GIOU一样对各种因素考虑周全)
  • 每个iteration后设置损失函数 L = ∑ k ∈ K ∑ t ∈ T λ t k L t k L=\sum_{k\in K}\sum_{t\in T}\lambda_t^k L_t^k L=kKtTλtkLtk中各分量的权重为 L c l s k L t k \frac{L^k_{cls}}{L^k_t} LtkLclsk,只有RPN例外( × 0.2 \times0.2 ×0.2
    [论文解读] Rank & Sort Loss for Object Detection and Instance Segmentation_第1张图片

代码解读

class RankSort(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, targets, delta_RS=0.50, eps=1e-10):
        # ---------------------------------------------------------#
        # targets: continuous label for each sample (e.g. IoU)
        # delta_RS: parameter of piecewise step functions
        # ---------------------------------------------------------#

        classification_grads = torch.zeros(logits.shape).cuda()

        # ---------------------#
        # Filter fg logits
        # ---------------------#
        fg_labels = (targets > 0.)
        fg_logits = logits[fg_labels]
        fg_targets = targets[fg_labels]
        fg_num = len(fg_logits)

        sorting_error = torch.zeros(fg_num).cuda()
        ranking_error = torch.zeros(fg_num).cuda()
        fg_grad = torch.zeros(fg_num).cuda()

        # --------------------------------------#
        # Filter non-trivial negative samples
        # --------------------------------------#
        # Do not use bg with scores less than minimum fg logit
        # since changing its score does not have an effect on precision
        threshold_logit = torch.min(fg_logits)-delta_RS
        relevant_bg_labels = ((targets == 0) & (logits >= threshold_logit))
        relevant_bg_logits = logits[relevant_bg_labels]
        relevant_bg_grad = torch.zeros(len(relevant_bg_logits)).cuda()

        # -----------------------------#
        # Loop on posivite indices
        # -----------------------------#
        # sort the fg logits and loops over each positive following the order
        order = torch.argsort(fg_logits)
        for ii in order:
            # --------------------------------#
            # Difference Transforms (x_ij)
            # --------------------------------#
            fg_relations = fg_logits - fg_logits[ii]
            bg_relations = relevant_bg_logits - fg_logits[ii]

            # ----------------------#
            # H(x_ij) \in [0,1]
            # ----------------------#
            # piecewise step functions
            if delta_RS > 0:
                fg_relations = torch.clamp(fg_relations / (2 * delta_RS) + 0.5, min=0, max=1)
                bg_relations = torch.clamp(bg_relations / (2 * delta_RS) + 0.5, min=0, max=1)
            # common function
            else:
                fg_relations = (fg_relations >= 0).float()
                bg_relations = (bg_relations >= 0).float()

            # Rank of ii among pos and false positive number (bg with larger scores)
            rank_pos = torch.sum(fg_relations)
            FP_num = torch.sum(bg_relations)
            # Rank of ii among all examples
            rank = rank_pos + FP_num

            # ------------------#
            # Ranking error
            # ------------------#
            # Since target_ranking_error is always 0, we here store current_ranking_error as ranking_error
            ranking_error[ii] = FP_num/rank

            # ------------------------#
            # Current sorting error
            # ------------------------#
            current_sorting_error = torch.sum(fg_relations * (1 - fg_targets)) / rank_pos

            # -------------------------------------------------------------------#
            # Find examples ranking higher and targets larger than ii
            # -------------------------------------------------------------------#
            iou_relations = (fg_targets >= fg_targets[ii])
            target_sorted_order = iou_relations * fg_relations
            # The rank of ii among positives in sorted order
            rank_pos_target = torch.sum(target_sorted_order)

            # ------------------------------------------------------#
            # Target sorting error  (target ranking error is 0)
            # ------------------------------------------------------#
            # Since target_ranking_error is always 0, target_sorting_error is also the target_error
            target_sorting_error = torch.sum(target_sorted_order * (1 - fg_targets)) / rank_pos_target

            # ------------------#
            # Sorting error
            # ------------------#
            sorting_error[ii] = current_sorting_error - target_sorting_error

            # -------------------------------------#
            # Identity Update for Ranking Error
            # -------------------------------------#
            if FP_num > eps:
                # For ii the update is the ranking error
                fg_grad[ii] -= ranking_error[ii]
                # For negatives, distribute error via ranking pmf (i.e. bg_relations/FP_num)
                relevant_bg_grad += (bg_relations * (ranking_error[ii]/FP_num))

            # --------------------------------------------------------------#
            # Find examples ranking higher but targets smaller than ii
            # --------------------------------------------------------------#
            # Find the positives that are misranked (the ones with smaller IoU but larger logits)
            missorted_examples = (~ iou_relations) * fg_relations

            # Denominotor of sorting pmf
            sorting_pmf_denom = torch.sum(missorted_examples)

            # -------------------------------------#
            # Identity Update for Sorting Error
            # -------------------------------------#
            if sorting_pmf_denom > eps:
                # For ii the update is the sorting error
                fg_grad[ii] -= sorting_error[ii]
                # For positives, distribute error via sorting pmf (i.e. missorted_examples/sorting_pmf_denom)
                fg_grad += (missorted_examples * (sorting_error[ii] / sorting_pmf_denom))

        # Normalize gradients by number of positives
        classification_grads[fg_labels] = (fg_grad/fg_num)
        classification_grads[relevant_bg_labels] = (relevant_bg_grad/fg_num)

        ctx.save_for_backward(classification_grads)

        return ranking_error.mean(), sorting_error.mean()

    @staticmethod
    def backward(ctx, out_grad1, out_grad2):
        g1, = ctx.saved_tensors
        return g1*out_grad1, None, None, None

你可能感兴趣的:(目标检测,深度学习,计算机视觉,回归,分类)