论文链接:Rank & Sort Loss for Object Detection and Instance Segmentation
mmdet实现代码:Rank & Sort Loss for Object Detection and Instance Segmentation
当下常用的损失函数形式如下,将任务 t t t在步骤 k k k的损失加权求和:
L = ∑ k ∈ K ∑ t ∈ T λ t k L t k L=\sum_{k\in K}\sum_{t\in T}\lambda_t^k L_t^k L=k∈K∑t∈T∑λtkLtk
缺点是超参数过多,很容易引起特定任务之间的不平衡,如正负样本不平衡,级联网络内部的不平衡等,最终得到次优的解决方案。
AP Loss和aLRP Loss与传统的基于分类得分的损失函数相比,具有训练过程和网络评估指标一致性(直接优化网络评估指标AP/aLRP等)、待调超参数较少、对类别不均衡不敏感等优势,但需要更长的训练时间和更多增强操作,且没有建模正样本之间的关联。
有研究证明,采用辅助头对正样本定位质量进行排序,或监督分类器直接回归样本的IoU(预测定位精度)能够提高网络性能。
因此本文提出RS Loss,不仅将正样本排序在负样本之前,还基于连续的IoU值在正样本内部进行排序,这将有以下好处:
很多实验证明,用辅助器预测检测结果的定位质量、中心度、IoU、mask-IoU或置信度,并将这些预测与NMS的分类得分相结合,可以提高检测性能。还有研究发现,使用连续的IoU值比使用过辅助器监督分类器效果更好,由此产生使用连续标签训练分类器的Quality Focal Loss并表现出对类别不均衡数据集的鲁棒性。
基于排序的损失不可微、难以优化。black-box solvers采用插值AP解决该问题但收效甚微;DR Loss通过对Hinge Loss引入margin实现正负排序;AP Loss和aLRP Loss对性能评估指标进行优化,通过感知学习的误差驱动算法实现不可微部分的优化,但他们需要更长的训练时间和更多的增强手段。RS Loss与之区别在于将连续的定位质量得分作为标签。
常见的解决方法是引入超参数并通过网格化搜索的方式进行调参。有实验采用自平衡策略来平衡分类和定位分支,使两者在aLRP Loss的限定范围内竞争;还有研究使用分类和定位损失的比率来平衡这些任务。
本文中,不同任务的损失值 L t k L_t^k Ltk有自己的限定范围,因此不同任务之间没有竞争关系,
AP Loss虽然基于排名重新定义了以AP为优化目标的损失函数,并借助感知器误差驱动优化算法实现反向传播,但有以下两处不足:
首先定义当前损失值 l R S ( i ) l_{RS}(i) lRS(i)为正样本 i i i的 ranking error 和 sorting error 之和,其中ranking error代表正负样本间的排序损失(参考aLRP Loss形式),sorting error对正样本中得分 s j > s i s_j\gt s_i sj>si的样本进行惩罚,其在[0,1]内连续的标签值(如IoU)越大,惩罚项越小:
l R S ( i ) : = N F P ( i ) r a n k ( i ) + ∑ j ∈ P H ( x i j ) ( 1 − y j ) r a n k + ( i ) : = 得 分 比 正 样 本 i 高 的 负 样 本 个 数 正 确 排 名 + 得 分 更 高 的 正 样 本 的 l a b e l s 惩 罚 \begin{aligned} l_{RS}(i)&:=\textcolor{blue}{\frac{N_{FP}(i)}{rank(i)}}+\textcolor{orange}{\frac{\sum_{j\in P}H(x_{ij})(1-y_j)}{rank^+(i)}}\\ \\ &:=\textcolor{blue}{\frac{得分比正样本i高的负样本个数}{正确排名}}+\textcolor{orange}{得分更高的正样本的labels惩罚} \end{aligned} lRS(i):=rank(i)NFP(i)+rank+(i)∑j∈PH(xij)(1−yj):=正确排名得分比正样本i高的负样本个数+得分更高的正样本的labels惩罚
定义目标损失值 l R S ∗ ( i ) l^*_{RS}(i) lRS∗(i)如下,当正样本 i i i排序在所有负样本之前时, l R ∗ ( i ) = 0 l^*_{R}(i)=0 lR∗(i)=0, l S ∗ ( i ) l^*_{S}(i) lS∗(i)对所有标签值 y j y_j yj大于样本 i i i得分的正样本的 1 − y j 1-y_j 1−yj求均值,
l R S ∗ ( i ) = l R ∗ ( i ) + ∑ j ∈ P H ( x i j ) [ y j ≥ y i ] ( 1 − y j ) ∑ j ∈ P H ( x i j ) [ y j ≥ y i ] = 0 + 正 样 本 中 l a b e l 大 于 i 的 样 本 ( 1 − y i ) 之 和 正 样 本 中 l a b e l 大 于 i 的 样 本 个 数 \begin{aligned} l^*_{RS}(i)&=\textcolor{blue}{l^*_{R}(i)}+\textcolor{orange}{\frac{\sum_{j\in P}H(x_{ij})[y_j\ge y_i](1-y_j)}{\sum_{j\in P}H(x_{ij})[y_j\ge y_i]}}\\ \\ &=\textcolor{blue}{0}+\textcolor{orange}{\frac{正样本中label大于i的样本(1-y_i)之和}{正样本中label大于i的样本个数}} \end{aligned} lRS∗(i)=lR∗(i)+∑j∈PH(xij)[yj≥yi]∑j∈PH(xij)[yj≥yi](1−yj)=0+正样本中label大于i的样本个数正样本中label大于i的样本(1−yi)之和
RS Loss定义为正样本的当前 l R S ( i ) l_{RS}(i) lRS(i)和目标值 l R S ∗ ( i ) l^*_{RS}(i) lRS∗(i)差异的均值: 1 ∣ P ∣ ∑ i ∈ P ( l R S ( i ) − l R S ∗ ( i ) ) \frac{1}{|P|}\sum_{i\in P}\left( l_{RS}(i)-l^*_{RS}(i) \right) ∣P∣1∑i∈P(lRS(i)−lRS∗(i))
参照AP Loss的三步定义,在此定义兼顾了正负样本和正样本内部误差的 primary term L i j L_{ij} Lij如下:
L i j = { ( l R ( i ) − l R ∗ ( i ) ) p R ( j ∣ i ) f o r i ∈ P , j ∈ N ( l S ( i ) − l S ∗ ( i ) ) p S ( j ∣ i ) f o r i ∈ P , j ∈ P 0 o t h e r w i s e L_{ij}= \left\{ \begin{array} {rcl} (l_R(i)-l^*_R(i))p_R(j|i) & {for\space i\in P, j\in N}\\ (l_S(i)-l^*_S(i))p_S(j|i) & {for\space i\in P, j\in P}\\ 0 & {otherwise} \end{array} \right. Lij=⎩⎨⎧(lR(i)−lR∗(i))pR(j∣i)(lS(i)−lS∗(i))pS(j∣i)0for i∈P,j∈Nfor i∈P,j∈Potherwise
其中 p R ( j ∣ i ) , p S ( j ∣ i ) p_R(j|i),p_S(j|i) pR(j∣i),pS(j∣i)负责将样本 i i i上的误差分别分布在导致误差的样本 j j j上(如只有得分 s j > s i s_j > s_i sj>si的负样本 j j j会引起 ranking error;只有得分 s j > s i s_j > s_i sj>si且标签值 y j < y i y_j < y_i yj<yi的正样本 j j j会引起 sorting error),即:
p R ( j ∣ i ) = H ( x i j ) ∑ k ∈ N H ( x i k ) p S ( j ∣ i ) = H ( x i j ) [ y j < y i ] ∑ k ∈ P H ( x i k ) [ y k < y i ] \begin{aligned} &p_R(j|i)=\frac{H(x_{ij})}{\sum_{k\in N}H(x_{ik})} \\ &p_S(j|i)=\frac{H(x_{ij})[y_j
因此,Identity Update过程可以分ranking error和 sorting error两部分进行:正样本同时受到两部分的影响,而负样本的定位精度对RS Loss没有影响,因此只被ranking error更新。
ATSS的损失定义为 L A T S S = L c l s + λ b o x L b o x + λ c t r L c t r L_{ATSS}=L_{cls}+\lambda_{box}L_{box}+\lambda_{ctr}L_{ctr} LATSS=Lcls+λboxLbox+λctrLctr,其中三个分量分别代表分类损失、定位损失(GIOU)和中心点定位损失(交叉熵损失)。
在此删除网络中的辅助头,并用RS Loss替代分类损失,其中连续标签值为目标框与真值框的IoU值,得到 L R S − A T S S = L R S + λ b o x L b o x L_{RS-ATSS}=L_{RS}+\lambda_{box}L_{box} LRS−ATSS=LRS+λboxLbox,超参数 λ b o x \lambda_{box} λbox通常通过网格搜索设置为常数,在此采用两种启发式无调优算法来确定每个iteration的 λ b o x \lambda_{box} λbox值:
本文实验发现基于损失值的 λ b o x \lambda_{box} λbox设置方法与调参效果相近;并且本文用每个预测框的分类得分对它们的定位GIoU损失进行加权。这两个tricks(基于数值的任务平衡、基于得分的实例加权)都与超参数无关,可以应用于所有网络。
RS Loss 使用总结
class RankSort(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, targets, delta_RS=0.50, eps=1e-10):
# ---------------------------------------------------------#
# targets: continuous label for each sample (e.g. IoU)
# delta_RS: parameter of piecewise step functions
# ---------------------------------------------------------#
classification_grads = torch.zeros(logits.shape).cuda()
# ---------------------#
# Filter fg logits
# ---------------------#
fg_labels = (targets > 0.)
fg_logits = logits[fg_labels]
fg_targets = targets[fg_labels]
fg_num = len(fg_logits)
sorting_error = torch.zeros(fg_num).cuda()
ranking_error = torch.zeros(fg_num).cuda()
fg_grad = torch.zeros(fg_num).cuda()
# --------------------------------------#
# Filter non-trivial negative samples
# --------------------------------------#
# Do not use bg with scores less than minimum fg logit
# since changing its score does not have an effect on precision
threshold_logit = torch.min(fg_logits)-delta_RS
relevant_bg_labels = ((targets == 0) & (logits >= threshold_logit))
relevant_bg_logits = logits[relevant_bg_labels]
relevant_bg_grad = torch.zeros(len(relevant_bg_logits)).cuda()
# -----------------------------#
# Loop on posivite indices
# -----------------------------#
# sort the fg logits and loops over each positive following the order
order = torch.argsort(fg_logits)
for ii in order:
# --------------------------------#
# Difference Transforms (x_ij)
# --------------------------------#
fg_relations = fg_logits - fg_logits[ii]
bg_relations = relevant_bg_logits - fg_logits[ii]
# ----------------------#
# H(x_ij) \in [0,1]
# ----------------------#
# piecewise step functions
if delta_RS > 0:
fg_relations = torch.clamp(fg_relations / (2 * delta_RS) + 0.5, min=0, max=1)
bg_relations = torch.clamp(bg_relations / (2 * delta_RS) + 0.5, min=0, max=1)
# common function
else:
fg_relations = (fg_relations >= 0).float()
bg_relations = (bg_relations >= 0).float()
# Rank of ii among pos and false positive number (bg with larger scores)
rank_pos = torch.sum(fg_relations)
FP_num = torch.sum(bg_relations)
# Rank of ii among all examples
rank = rank_pos + FP_num
# ------------------#
# Ranking error
# ------------------#
# Since target_ranking_error is always 0, we here store current_ranking_error as ranking_error
ranking_error[ii] = FP_num/rank
# ------------------------#
# Current sorting error
# ------------------------#
current_sorting_error = torch.sum(fg_relations * (1 - fg_targets)) / rank_pos
# -------------------------------------------------------------------#
# Find examples ranking higher and targets larger than ii
# -------------------------------------------------------------------#
iou_relations = (fg_targets >= fg_targets[ii])
target_sorted_order = iou_relations * fg_relations
# The rank of ii among positives in sorted order
rank_pos_target = torch.sum(target_sorted_order)
# ------------------------------------------------------#
# Target sorting error (target ranking error is 0)
# ------------------------------------------------------#
# Since target_ranking_error is always 0, target_sorting_error is also the target_error
target_sorting_error = torch.sum(target_sorted_order * (1 - fg_targets)) / rank_pos_target
# ------------------#
# Sorting error
# ------------------#
sorting_error[ii] = current_sorting_error - target_sorting_error
# -------------------------------------#
# Identity Update for Ranking Error
# -------------------------------------#
if FP_num > eps:
# For ii the update is the ranking error
fg_grad[ii] -= ranking_error[ii]
# For negatives, distribute error via ranking pmf (i.e. bg_relations/FP_num)
relevant_bg_grad += (bg_relations * (ranking_error[ii]/FP_num))
# --------------------------------------------------------------#
# Find examples ranking higher but targets smaller than ii
# --------------------------------------------------------------#
# Find the positives that are misranked (the ones with smaller IoU but larger logits)
missorted_examples = (~ iou_relations) * fg_relations
# Denominotor of sorting pmf
sorting_pmf_denom = torch.sum(missorted_examples)
# -------------------------------------#
# Identity Update for Sorting Error
# -------------------------------------#
if sorting_pmf_denom > eps:
# For ii the update is the sorting error
fg_grad[ii] -= sorting_error[ii]
# For positives, distribute error via sorting pmf (i.e. missorted_examples/sorting_pmf_denom)
fg_grad += (missorted_examples * (sorting_error[ii] / sorting_pmf_denom))
# Normalize gradients by number of positives
classification_grads[fg_labels] = (fg_grad/fg_num)
classification_grads[relevant_bg_labels] = (relevant_bg_grad/fg_num)
ctx.save_for_backward(classification_grads)
return ranking_error.mean(), sorting_error.mean()
@staticmethod
def backward(ctx, out_grad1, out_grad2):
g1, = ctx.saved_tensors
return g1*out_grad1, None, None, None