源码解析目标检测的跨界之星DETR(五)、loss函数与匈牙利匹配算法

Date: 2020/07/17

Coder: CW

Foreword:

本文将对 loss函数的实现进行解析,由于 DETR 是预测结果是集合的形式,因此在计算loss的时候有个关键的前置步骤就是将预测结果和GT进行匹配,这里的GT类别是不包括背景的,未被匹配的预测结果就自动被归类为背景。匹配使用的是匈牙利算法,该算法主要用于解决与二分图匹配相关的问题,对这部分感兴趣的朋友们可以参考下这篇文:匈牙利算法


Outline

I. Loss Function

    i). 分类loss

    ii). 回归loss

II. Hungarian algorithm(匈牙利算法)


Loss Function

先来看看与loss函数相关的一些参数:matcher就是将预测结果与GT进行匹配的匈牙利算法,这部分的实现会在下一节解析。weight_dict是为各部分loss设置的权重,主要包括分类与回归损失,分类使用的是交叉熵损失,而回归损失包括bbox的 L1 Loss(计算x、y、w、h的绝对值误差)与 GIoU Loss。若设置了masks参数,则代表分割任务,那么还需加入对应的loss类型。另外,若设置了aux_loss,即代表需要计算解码器中间层预测结果对应的loss,那么也要设置对应的loss权重。

与loss函数实现相关的初始化参数

loss函数是通过实例化SetCriterion对象来构建。

构建loss函数

losses变量指示需要计算哪些类型的loss,其中cardinality仅用作log,并不涉及反向传播梯度。

loss_cardinality

可以先来看下SetCriterion这个类的doc string,了解下各部分参数的意义。

SetCriterion(i)

CW 也作了对应的注释:

SetCriterion(ii)

接下来看下其前向过程,从而知悉loss的计算。

这里一定要先搞清楚模型的输出(outputs)和GT(targets)的形式,对于outputs可参考CW在下图中的注释;而targets是一个包含多个dict的list,长度与batch size相等,其中每个dict的形式如COCO数据集的标注,具体可参考该系列的第二篇文章: 源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理 中的数据处理部分。

SetCriterion(iii)

如CW在前言部分所述,计算loss的一个关键的前置步骤就是将模型输出的预测结果与GT进行匹配,对应下图中self.matcher()的部分,返回的indices的形式已在注释中说明。

SetCriterion(iv)

接下来是计算各种类型的loss,并将对应结果存到一个dict中(如下图losses变量),self.get_loss()方法返回loss计算结果。

SetCriterion(v)
SetCriterion(vi)

get_loss方法中并不涉及具体loss的计算,其仅仅是将不同类型的loss计算映射到对应的方法,最后将计算结果返回。

get_loss

接下来,我们就对分类和回归损失的计算过程分别进行解析。

i). 分类loss

首先说明下,doc string里写的是NLL Loss,但实际调用的是CE Loss,这是因为在Pytorch实现中,CE Loss实质上就是将Log-Softmax操作和NLL Loss封装在了一起,如果直接使用NLL Loss,那么需要先对预测结果作Log-Softmax操作,而使用CELoss则直接免去了这一步。

loss_labels(i)

其次,要理解红框部分的_get_src_permutation_idx()在做什么。输入参数indices是匹配的预测(query)索引与GT的索引,其形式在上述SetCriterion(iv)图中注释已有说明。该方法返回一个tuple,代表所有匹配的预测结果的batch index(在当前batch中属于第几张图像)和 query index(图像中的第几个query对象)。

_get_src_permutation_idx

类似地,我们可以获得当前batch中所有匹配的GT所属的类别(target_classes_o),然后通过src_logitstarget_classes_o就可以设置预测结果对应的GT了,这就是下图中的target_classes。target_classes的shape和src_logits一致,代表每个query objects对应的GT,首先将它们全部初始化为背景,然后根据匹配的索引(idx)设置匹配的GT(target_classes_p)类别。

loss_labels(ii) 

“热身活动”做完后,终于可以开始计算loss了,注意在使用Pytorch的交叉熵损失时,需要将预测类别的那个维度转换到通道这个维度上(dim1)。

loss_labels(iii) 

另外,class_error计算的是Top-1精度(百分数),即预测概率最大的那个类别与对应被分配的GT类别是否一致,这部分仅用于log,并不参与模型训练。

accuracy

ii). 回归loss

回归loss的计算包括预测框与GT的中心点和宽高的L1 loss以及GIoU loss

注意在下图注释中,num_matched_queries1+num_matched_queries2+..., 和 num_matched_objs1+num_matched_objs2+... 是相等的,在前面 SetCriterion(iv) 图中matcher的返回结果注释中有说明。

loss_boxes(i)

以下就是loss的计算。注意下 reduction 参数,若不显式进行设置,在Pytorch的实现中默认是'mean',即返回所有涉及误差计算的元素的均值。

loss_boxes(ii)

另外,在计算GIoU loss时,使用了torch.diag()获取对角线元素,这是因为generalized_box_iou()方法返回的是所有预测框与所有GT的GIoU,比如预测框有N个,GT有M个,那么返回结果就是NxM个GIoU。而如 loss_boxes(i) 图中所示,我们预先对匹配的预测框和GT进行了排列,即N个预测框中的第1个匹配M个GT中的第1个,N中第2个匹配M中第2个,..,N中第i个匹配M中第i个,于是我们要取相互匹配的那一项来计算loss。

generalized_box_iou(i)
generalized_box_iou(ii)

Hungarian algorithm(匈牙利算法)

build_matcher()方法返回HungarianMatcher对象,其实现了匈牙利算法,在这里用于预测集(prediction set)和GT的匹配,最终匹配方案是选取“loss总和”最小的分配方式。注意CW对loss总和这几个字用了引号,其与loss函数中计算的loss并不完全一致,在这里是作为度量cost/metric)的角色,度量的值决定了匹配的结果,接下来我们看代码实现就会一清二楚。

build_matcher

如doc string所述,GT是不包含背景类的,通常预测集中的物体数量(默认为100)会比图像中实际存在的目标数量多,匈牙利算法按1对1的方式进行匹配,没有被匹配到的预测物体就自动被归类为背景(non-objects)。

HungarianMatcher(i)

以下cost_xx代表各类型loss的相对权重,在匈牙利算法中,描述为各种度量的相对权重会更合适,因此,这里命名使用的是'cost'。

HungarianMatcher(ii)

现在来看看前向过程,注意这里是不需要梯度的

HungarianMatcher(iii)

首先将预测结果和GT进行reshape,并对应起来,方便进行计算。

HungarianMatcher(iv)

注:以上tgt_bbox等式右边的torch.cat()方法中应加上参数dim=0

然后就可以对各种度量(各类型loss)进行计算。

如代码所示,这里的cost与上一节解析的loss并不完全一样,比如对于分类来说,loss计算使用的是交叉熵,而这里为了更加简便,直接采用1减去预测概率的形式,同时由于1是常数,于是作者甚至连1都省去了,有够机智(懒)的...

HungarianMatcher(v)

另外,在计算bbox的L1误差时,使用了torch.cdist(),其中设置参数p=1代表L1范式(默认是p=2,即L2范式),这个方法会对每对预测框与GT都进行误差计算:比如预测框有N个,GT有M个,结果就会有NxM个值。

接着对各部分度量加权求和,得到一个总度量。然后,统计当前batch中每张图像的GT数量,这个操作是为什么呢?接着看,你会发现这招很妙!

C.split()在最后一维按各张图像的目标数量进行分割,这样就可以在各图像中将预测结果与GT进行匹配了。

HungarianMatcher(vi)

匹配方法使用的是scipy优化模块中的linear_sum_assignment(),其输入是二分图的度量矩阵,该方法是计算这个二分图度量矩阵的最小权重分配方式,返回的是匹配方案对应的矩阵行索引和列索引。

linear_sum_assignment

结尾日常吹水

吾以为,loss函数的设计是DL项目中最重要的部分之一。CW每次看项目的源码时,最打起精神的就是这一part了。

从数学的角度来看,DL本质上是一个优化问题,loss是模型学习目标在数学上的表达形式,我们期望模型朝着loss最小的方向发展,因此,loss函数的设计关系到优化的可行性及难易程度,可谓成败之关键。因此,这部分其实很考验炼丹师的功力,也最能体现一个人考虑和解决问题的思想。

如今,我们是站在前人(一堆大佬,不,是巨佬!)的肩膀上,日常无脑地来来去去都用那几种loss,真是幸福的新生儿呐!

你可能感兴趣的:(源码解析目标检测的跨界之星DETR(五)、loss函数与匈牙利匹配算法)