DETR(DEtection TRansformer)是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题,摒弃了锚框设计和非极大值抑制(NMS)等复杂后处理步骤。DETR使用卷积神经网络提取图像特征,并将其通过位置编码转换为输入序列,送入Transformer的Encoder-Decoder结构。Decoder通过固定数量的目标查询(Object Queries),预测类别和边界框位置。DETR创新性地引入匈牙利算法进行二分图匹配,确保预测与真实值的唯一对应关系,且采用交叉熵损失和L1-GIoU损失进行优化。在COCO数据集上的实验表明,DETR在大目标检测中表现优异,并能灵活迁移到其他任务,如全景分割。
DETR (DEtection TRansformer) is an end-to-end target detection method based on Transformer architecture proposed by Facebook AI. By modeling object detection as a set prediction problem, it eliminates complex post-processing steps such as anchor frame design and non-maximum suppression (NMS). DETR uses convolutional neural networks to extract image features and convert them via positional encoding into input sequences that feed into Transformer’s Encoder-Decoder structure. Decoder predicts categories and bounding box positions with a fixed number of Object Queries. DETR innovates by introducing the Hungarian algorithm for bipartite graph matching to ensure a unique relationship between the prediction and the true value, and optimizes with cross-entropy losses and L1-GIoU losses. Experiments on the COCO dataset show that DETR performs well in large target detection and can be flexibly migrated to other tasks, such as panoramic segmentation.
DETR(DEtection TRansformer)是由Facebook AI在2020年提出的一种基于Transformer架构的端到端目标检测方法。与传统的目标检测方法(如Faster R-CNN、YOLO等)不同,DETR直接将目标检测建模为一个集合预测问题,摆脱了锚框设计和复杂的后处理(如NMS)。结果在 COCO 数据集上效果与 Faster RCNN 相当,在大目标上效果比 Faster RCNN 好,且可以很容易地将 DETR 迁移到其他任务例如全景分割。
简单来说,就是通过CNN提取图像特征(通常 Backbone 的输出通道为 2048,图像高和宽都变为了 1/32),并经过input embedding+positional encoding操作转换为图像序列(如下图所说,就是类似[N, HW, C]的序列)作为transformer encoder的输入,得到了编码后的图像序列,在图像序列的帮助下,将object queries(下图中说的是固定数量的可学习的位置embeddings)转换/预测为固定数量的类别+bbox预测。相当于Transformer本质上起了一个序列转换的作用。
下图为DETR的详细结构:
DETR中的encoder-decoder与transformer中的encoder-decoder对比:
匈牙利算法是用于解决二分图匹配的问题,即将Ground Truth的K个bbox和预测出的100个bbox作为二分图的两个集合,匈牙利算法的目标就是找到最大匹配,即在二分图中最多能找到多少条没有公共端点的边。匈牙利算法的输入就是每条边的cost 矩阵
思考:
DETR 预测了一组固定大小的 N = 100 个边界框,这比图像中感兴趣的对象的实际数量大得多。怎么样来计算损失呢?或者说预测出来的框我们怎么知道对应哪一个 ground-truth 的框呢?
为了解决这个问题,第一步是将 ground-truth 也扩展成 N = 100 个检测框。使用了一个额外的特殊类标签 ϕ \phiϕ 来表示在未检测到任何对象,或者认为是背景类别。这样预测和真实都是两个100 个元素的集合了。这时候采用匈牙利算法进行二分图匹配,即对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。
σ ^ = arg min G ∈ G N ∑ i N L m a t c h ( y i , y ^ σ ( i ) ) \hat{\sigma}=\arg\min_{\mathrm{G\in G_N}}\sum_{\mathrm{i}}^{\mathrm{N}}\mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right) σ^=argG∈GNmini∑NLmatch(yi,y^σ(i))
L m a t c h ( y i , y ^ σ ( i ) ) = − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L b o x ( b i , b ^ σ ( i ) ) \mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right)=-1_{\{\mathrm{c_i}\neq\varnothing\}}\hat{\mathrm{p}}_{\mathrm{\sigma(i)}}\left(\mathrm{c_i}\right)+1_{\{\mathrm{c_i}\neq\varnothing\}}\mathcal{L}_{\mathrm{box}}\left(\mathrm{b_i},\hat{\mathrm{b}}_{\mathrm{\sigma(i)}}\right) Lmatch(yi,y^σ(i))=−1{ci=∅}p^σ(i)(ci)+1{ci=∅}Lbox(bi,b^σ(i))
对于那些不是背景的,获得其对应的预测是目标类别的概率,然后用框损失减去预测类别概率。这也就是说不仅框要近,类别也要基本一致,是最好的。经过匈牙利算法之后,我们就得到了 ground truth 和预测目标框之间的一一对应关系。然后就可以计算损失函数了。
下面是利用pytorch实现DETR的代码:
位置编码部分:
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:x.size(0), :]
用于为序列数据(如Transformer中的输入)添加位置信息。位置编码帮助模型保留序列中元素的位置信息,这是因为Transformer模型本身不具备位置信息感知能力。
使用正弦和余弦函数优点:
优点:
正弦和余弦具有周期性和平滑性;
不同维度具有不同频率,编码了多尺度的位置信息。
作用:保留序列的位置信息,使模型能够感知数据的顺序。
编码可视化结果:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# 位置编码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:x.size(0), :]
pe = PositionalEncoding(d_model=16, max_len=100)
x = torch.zeros(100, 1, 16)
encoded = pe(x).squeeze(1).detach().numpy()
plt.figure(figsize=(10, 5))
plt.imshow(encoded, aspect='auto', cmap='viridis')
plt.colorbar(label='Encoding Value')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding Visualization')
plt.show()
同一位置的编码:
值的分布(正弦和余弦的相互作用)保证了每个位置在多维空间中具有唯一性。
时间步的相对差异:
相邻位置(如第1和第2位置)在高维上的值差异较大,这为模型提供了感知时间步变化的能力。
encoder-decoder:
class Transformer(nn.Module):
def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6):
super().__init__()
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)
self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_decoder_layers)
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
memory = self.encoder(src, mask=src_mask)
output = self.decoder(tgt, memory, tgt_mask=tgt_mask)
return output
DETR模型:
# DETR模型
class DETR(nn.Module):
def __init__(self, num_classes, num_queries, backbone='resnet50'):
super().__init__()
self.num_queries = num_queries
# Backbone
self.backbone = models.resnet50(pretrained=True)
self.conv = nn.Conv2d(2048, 256, kernel_size=1)
# Transformer
self.transformer = Transformer(d_model=256)
self.query_embed = nn.Embedding(num_queries, 256)
self.positional_encoding = PositionalEncoding(256)
# Prediction heads
self.class_embed = nn.Linear(256, num_classes + 1) # +1 for no-object class
self.bbox_embed = nn.Linear(256, 4)
def forward(self, images):
# Feature extraction
features = self.backbone(images)
features = self.conv(features)
h, w = features.shape[-2:]
# Flatten and add positional encoding
src = features.flatten(2).permute(2, 0, 1) # (HW, N, C)
src = self.positional_encoding(src)
# Query embedding
query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, images.size(0), 1) # (num_queries, N, C)
# Transformer
hs = self.transformer(src, query_embed)
# Prediction
outputs_class = self.class_embed(hs)
outputs_coord = self.bbox_embed(hs).sigmoid() # Normalized to [0, 1]
return {'pred_logits': outputs_class, 'pred_boxes': outputs_coord}
DETR通过Transformer实现端到端的目标检测,无需(如NMS)复杂的后处理。相比传统检测器,DETR具有简洁的架构和强大的全局建模能力,但训练时对数据和计算资源的需求较高。
DETR简化了目标检测的流程,摒弃了传统检测器中繁琐的锚框设计和后处理步骤,架构更简洁,且依托于Transformer的全局建模能力,在捕捉长距离特征关系方面表现出色。相比传统方法,DETR在目标数量固定的场景下,能够更高效地处理目标检测任务。其优点包括易迁移、多任务适用性和端到端优化能力,但其劣势在于训练时间较长、计算资源消耗较大,尤其是在小目标检测和训练数据量不足的情况下效果略显不足。