DN-DETR源码讲解

文章目录

  • 一:创新点
  • 二:源码分析
      • DAB-DETR主模块
      • Transformer
      • Loss计算细节

一:创新点

DN-DETR中的DN指的是denoising,即“去噪”,是一种训练时加快收敛速度的trick。作者将网络拆分为了Denoising Part和Matching Part,只有在训练的时才有Denoising Part,inference时去除。

DN-DETR的主框架和Conditional DETR、DAB DETR完全类似,对它们还不熟悉的小伙伴可以看Conditional DETR和DAB DETR这两篇文章。下面展示一下整体网络图:

DN-DETR源码讲解_第1张图片

二:源码分析

DAB-DETR主模块

  • def init()

DN-DETR源码讲解_第2张图片DN-DETR源码讲解_第3张图片

  • def forward()

DN-DETR源码讲解_第4张图片DN-DETR源码讲解_第5张图片

init初始化时生成了[91 + 1, 256]的self.label_enc和[10, 4]的self.refpoint_embed,分别是label词缀表(最后一维其实是初始化tgt)和refpoint的初始化。

forward老生常谈,唯一的不同就是多了prepare_for_dn处理target数据,和dn post process对输出结果作拆分(将[3 2 30 91] 和 [3 2 30 4]分别拆为 [3 2 10 91]、[3 2 10 4]和[3 2 20 91]、[3 2 20 4],前两个作为真正的output和refpoints,后两个作为去噪后的labels和boxes结果扔进mask_dict中,用来计算去噪损失)。最后返回的是[3 2 10 91]、[3 2 10 4]和mask_dict。

让我们看一下prepare_for_dn函数源码:

DN-DETR源码讲解_第6张图片
DN-DETR源码讲解_第7张图片DN-DETR源码讲解_第8张图片DN-DETR源码讲解_第9张图片DN-DETR源码讲解_第10张图片

该函数的功能是由target中真实label和boxes,生成几组group的噪声target,然后拼接在一起。对于label是随机flip,而boxes则是改变center和w、h。得到[2, 20, 256]的input_label_embed和[2, 20, 4]的boxes。注意!最后还添加了[2 10 256]的tgt和[2 10 4]的refpoint_embed

还有一个重点,就是attn_mask。作者在论文中提出如下见解:

Therefore, our attention mask is to make sure the matching part cannot see the denoising part and the denoising groups
cannot see each other as shown in Fig. 4.

翻译一下就是在decoder时防止泄题,denoising part中各个group之间不能互相看到,matching part中的query不能看到denoising part中的groups。而denoising part中的groups看到matching part也没事,因为它们需要学习,里面不包含“答案”,

最后输出[2 30 256]的input_query_label,[2 30 4]的input_query_bbox,[30 30]的attn_mask,和包含了大量原始target和索引的字典(其中内容请看源码中的注释,用来最后计算Loss用的)。将它们和src等输入到transformer中,下面看transformer模块:

Transformer

DN-DETR源码讲解_第11张图片中规中矩,encoder和decoder中的细节就不讲了,和DAB-DETR一字不差。最后输出[3 2 30 256]的hs和[3 2 30 4]的references。最后我们再看一下Loss的计算细节:

Loss计算细节

  • engine.py
    DN-DETR源码讲解_第12张图片
  • criterion中的 forward(output, target)

DN-DETR源码讲解_第13张图片DN-DETR源码讲解_第14张图片DN-DETR源码讲解_第15张图片

Loss计算和DETR常规计算一样,只多了dn loss computation,计算去噪损失,下面是 dn_losses = compute_dn_loss(mask_dict, self.training, aux_num, self.focal_alpha)实现源码:

  • compute_dn_loss

DN-DETR源码讲解_第16张图片
DN-DETR源码讲解_第17张图片

通过prepare_for_loss对mask_dict进行处理,将[3 2 20 256]的output_known_coord和[3 2 20 4]的output_known_class中多余的zero行去除,提取出group_num✖label_num个真正的去噪target,文中是35(5✖7=35)个,该函数输出是[3 35 4]和[3 35 91]。

最后只对最后一维的结果计算去噪Loss,敲重点,这里计算[35 91]的label损失用的是focal函数,实现细节略。

最后展现一下prepare_for_losstgt_loss_labelstgt_loss_boxes

  • prepare_for_loss
    DN-DETR源码讲解_第18张图片
  • tgt_loss_labels

DN-DETR源码讲解_第19张图片

  • tgt_loss_boxes

DN-DETR源码讲解_第20张图片


  至此我对DN-DETR源码中全部的流程与细节,进行了深度讲解,希望对大家有所帮助,有不懂的地方或者建议,欢迎大家在下方留言评论。

我是努力在CV泥潭中摸爬滚打的江南咸鱼,我们一起努力,不留遗憾!

你可能感兴趣的:(Transformer,深度学习,计算机视觉,神经网络)