Paper:
Code:
VisionTransformer (ViT: encoder)
class VisionTransformer(BaseModule):
    """Vision Transformer."""

    def forward(self, inputs):  # e.g. inputs: (1,3,512,512)
        B = inputs.shape[0]
        x, hw_shape = self.patch_embed(inputs)  # x=(1,1024,192), hw_shape=(32,32)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        # self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        # here embed_dims=192, so cls_tokens=(1,1,192)
        x = torch.cat((cls_tokens, x), dim=1)  # (1,1025,192)
        x = self._pos_embeding(x, hw_shape, self.pos_embed)
        # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dims))
        # adds positional information to the tokens
        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)  # each layer is a TransformerEncoderLayer
            if i == len(self.layers) - 1:
                if self.final_norm:
                    x = self.norm1(x)
            if i in self.out_indices:  # out_indices = [11]
                if self.with_cls_token:
                    # Remove class token and reshape token for decoder head
                    out = x[:, 1:]
                else:
                    out = x
                B, _, C = out.shape
                out = out.reshape(B, hw_shape[0], hw_shape[1],
                                  C).permute(0, 3, 1, 2).contiguous()
                if self.output_cls_token:
                    out = [out, x[:, 0]]
                outs.append(out)

        return tuple(outs)  # single output of shape (1,192,32,32)
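To keep the shapes above concrete, here is a minimal self-contained sketch of the same token pipeline in plain PyTorch (patchify, prepend cls token, add position embedding). The module names and sizes are illustrative, not mmseg's actual classes:

# Minimal sketch of the ViT token pipeline above (not mmseg's implementation).
# Shapes follow the comments: 512x512 input, patch 16, embed_dims 192.
import torch
import torch.nn as nn

embed_dims, patch = 192, 16
proj = nn.Conv2d(3, embed_dims, kernel_size=patch, stride=patch)
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
pos_embed = nn.Parameter(torch.zeros(1, (512 // patch) ** 2 + 1, embed_dims))

img = torch.randn(1, 3, 512, 512)
x = proj(img)                                    # (1, 192, 32, 32)
hw_shape = x.shape[2:]                           # (32, 32)
x = x.flatten(2).transpose(1, 2)                 # (1, 1024, 192)
x = torch.cat([cls_token.expand(x.size(0), -1, -1), x], dim=1)  # (1, 1025, 192)
x = x + pos_embed                                # positional information added
print(x.shape)                                   # torch.Size([1, 1025, 192])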
TransformerEncoderLayer
class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Vision Transformer."""

    def forward(self, x):  # x = (1,1025,192)

        def _inner_forward(x):
            x = self.attn(self.norm1(x), identity=x)  # shape unchanged
            x = self.ffn(self.norm2(x), identity=x)   # x = (1,1025,192)
            return x

        if self.with_cp and x.requires_grad:  # cp: torch.utils.checkpoint (trades compute for memory)
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x
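The `identity=x` argument makes each sub-block a pre-norm residual: `x + Attn(LN(x))` followed by `x + FFN(LN(x))`. A rough stand-alone equivalent in plain PyTorch (illustrative names, not mmseg's attention/FFN wrappers):

import torch
import torch.nn as nn

class PreNormEncoderLayer(nn.Module):
    """Illustrative pre-norm block: x + Attn(LN(x)), then x + FFN(LN(x))."""
    def __init__(self, dim=192, num_heads=3, ffn_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ffn_ratio), nn.GELU(),
            nn.Linear(dim * ffn_ratio, dim))

    def forward(self, x):                       # x: (B, N, C), shape preserved
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        x = x + self.ffn(self.norm2(x))
        return x

layer = PreNormEncoderLayer()
print(layer(torch.randn(1, 1025, 192)).shape)   # torch.Size([1, 1025, 192])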
_pos_embeding
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
    """Positiong embeding method."""
PatchEmbed
class PatchEmbed(BaseModule):

    def forward(self, x):
        if self.adap_padding:
            x = self.adap_padding(x)
        x = self.projection(x)  # Conv2d(3, 192, kernel_size=16, stride=16); x=(1,192,32,32)
        out_size = (x.shape[2], x.shape[3])  # (32, 32)
        x = x.flatten(2).transpose(1, 2)  # x=(1,1024,192)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size
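Because the Conv2d kernel and stride both equal the patch size, the projection is the same as cutting the image into non-overlapping 16x16 patches and applying one shared linear layer to each. A quick sanity check of that equivalence (illustrative, not mmcv's PatchEmbed):

import torch
import torch.nn as nn

patch, embed_dims = 16, 192
proj = nn.Conv2d(3, embed_dims, kernel_size=patch, stride=patch)

img = torch.randn(1, 3, 512, 512)
feat = proj(img)                                # (1, 192, 32, 32)
tokens = feat.flatten(2).transpose(1, 2)        # (1, 1024, 192)

# Same thing with an explicit unfold + linear projection sharing the conv weights:
linear = nn.Linear(3 * patch * patch, embed_dims)
linear.weight.data = proj.weight.data.reshape(embed_dims, -1)
linear.bias.data = proj.bias.data
patches = nn.functional.unfold(img, kernel_size=patch, stride=patch)  # (1, 768, 1024)
tokens2 = linear(patches.transpose(1, 2))       # (1, 1024, 192)
print(torch.allclose(tokens, tokens2, atol=1e-5))  # True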
SegmenterMaskTransformerHead (decoder)
class SegmenterMaskTransformerHead(BaseDecodeHead):

    def forward(self, inputs):  # inputs from the backbone: (1,192,32,32)
        x = self._transform_inputs(inputs)
        # input_transform is None in the Segmenter config, so x = (1,192,32,32)
        b, c, h, w = x.shape  # 1, 192, 32, 32
        x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)  # x = (1,1024,192)
        x = self.dec_proj(x)  # self.dec_proj = nn.Linear(in_channels, embed_dims); x = (1,1024,192)

        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)  # (1,2,192)
        # self.cls_emb = nn.Parameter(torch.randn(1, self.num_classes, embed_dims))
        x = torch.cat((x, cls_emb), 1)  # x = (1,1026,192)
        for layer in self.layers:
            # these layers are transformer encoder layers: through attention,
            # every token (patch and class) gathers global context
            x = layer(x)
        x = self.decoder_norm(x)

        patches = self.patch_proj(x[:, :-self.num_classes])
        # self.patch_proj = nn.Linear (dimension unchanged); patches = (1,1024,192)
        cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])  # (1,2,192)

        patches = F.normalize(patches, dim=2, p=2)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)

        masks = patches @ cls_seg_feat.transpose(1, 2)  # (1,1024,2)
        masks = self.mask_norm(masks)
        masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)

        return masks  # (1,2,32,32)
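The final step is essentially a cosine similarity between every patch token and every class token: after L2 normalization, the dot product of a patch feature with a class feature becomes that patch's score for the class. A condensed sketch of just that step with random features (illustrative, `mask_norm` omitted):

import torch
import torch.nn.functional as F

b, n, c, num_classes, h, w = 1, 1024, 192, 2, 32, 32
patches = F.normalize(torch.randn(b, n, c), dim=2, p=2)            # patch tokens
cls_seg_feat = F.normalize(torch.randn(b, num_classes, c), dim=2)  # class tokens

masks = patches @ cls_seg_feat.transpose(1, 2)   # (1, 1024, 2): cosine similarities
masks = masks.permute(0, 2, 1).reshape(b, num_classes, h, w)
print(masks.shape)                               # torch.Size([1, 2, 32, 32])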