有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码
DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8(训练函数/损失函数)
位置:models/transformer.py/Transformer类
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
这个辅助函数,遍历当前被调用的模型中所有的参数,当这个参数的维度大于1时,会对当前参数使用Xavier均匀初始化方法进行初始化
class Transformer(nn.Module):
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False,
return_intermediate_dec=False):
super().__init__()
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def forward(self, src, mask, query_embed, pos_embed):
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
mask = mask.flatten(1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed)
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8(训练函数/损失函数)