Segmenter -- Code Notes

Paper:

Code:

VisionTransformer (ViT: encoder)

class VisionTransformer(BaseModule):
    """Vision Transformer."""
    def forward(self, inputs):  # eg: inputs:(1,3,512,512)
        B = inputs.shape[0]

        x, hw_shape = self.patch_embed(inputs)  # x=(1,1024,192) (32,32)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1) 
        # self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        # cls_tokens: (1,1,192)
        x = torch.cat((cls_tokens, x), dim=1)   # (1,1025,192)
        x = self._pos_embeding(x, hw_shape, self.pos_embed) # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dims))
        # add the position information to the tokens
        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)    # each layer is a TransformerEncoderLayer
            if i == len(self.layers) - 1:
                if self.final_norm:
                    x = self.norm1(x)
            if i in self.out_indices:   # out_indices = [11]
                if self.with_cls_token:
                    # Remove class token and reshape token for decoder head
                    out = x[:, 1:]
                else:
                    out = x
                B, _, C = out.shape
                out = out.reshape(B, hw_shape[0], hw_shape[1],
                                  C).permute(0, 3, 1, 2).contiguous()
                if self.output_cls_token:
                    out = [out, x[:, 0]]
                outs.append(out)

        return tuple(outs)  # (1,192,32,32)
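
As a sanity check on the shapes above, here is a tiny standalone snippet (not part of mmsegmentation) showing how the (1,192,32,32) encoder output follows from a 512x512 input, 16x16 patches and embed_dims=192:

import torch

# Toy shape check: 512/16 = 32 patches per side -> hw_shape = (32, 32),
# 32*32 = 1024 patch tokens (plus 1 cls token), each of dimension 192.
img_size, patch_size, embed_dims = 512, 16, 192
hw = img_size // patch_size                 # 32
num_patches = hw * hw                       # 1024

tokens = torch.randn(1, num_patches, embed_dims)                     # (1,1024,192)
out = tokens.reshape(1, hw, hw, embed_dims).permute(0, 3, 1, 2).contiguous()
print(out.shape)                            # torch.Size([1, 192, 32, 32])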

TransformerEncoderLayer

class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Vision Transformer."""
    def forward(self, x):   # x = (1,1025,192)

        def _inner_forward(x):
            x = self.attn(self.norm1(x), identity=x)    # shape unchanged; identity is the residual
            x = self.ffn(self.norm2(x), identity=x)     # x = (1,1025,192)
            return x

        if self.with_cp and x.requires_grad:    # cp: torch.utils.checkpoint (gradient checkpointing to save memory)
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x
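
In mmcv's MultiheadAttention and FFN, the identity argument is the tensor added back on the residual branch, so each layer computes x + Attn(LN(x)) followed by x + FFN(LN(x)). A minimal self-contained pre-norm block with the same structure, written with plain PyTorch modules and ViT-Tiny sizes (dim=192, 3 heads) assumed for illustration:

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Sketch of one pre-norm encoder layer: x + Attn(LN(x)), then x + FFN(LN(x))."""
    def __init__(self, dim=192, num_heads=3, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):                       # x: (B, N, C), shape preserved
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        x = x + self.ffn(self.norm2(x))
        return x

blk = PreNormBlock()
print(blk(torch.randn(1, 1025, 192)).shape)     # torch.Size([1, 1025, 192])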

_pos_embeding

def _pos_embeding(self, patched_img, hw_shape, pos_embed):
    """Position embedding method: resize pos_embed if the token count does
    not match, add it to the patch tokens, then apply dropout."""

PatchEmbed

class PatchEmbed(BaseModule):
    def forward(self, x):
        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)  # Conv2d(3,192,k=16,s=16); x=(1,192,32,32)
        out_size = (x.shape[2], x.shape[3]) # 32,32
        x = x.flatten(2).transpose(1, 2)    # x=(1,1024,192)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size
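
The projection is just a non-overlapping stride-16 convolution, so the 512x512 image becomes a 32x32 grid of 192-dim tokens. A quick standalone check:

import torch
import torch.nn as nn

# Each non-overlapping 16x16 patch is mapped to one 192-dim token.
proj = nn.Conv2d(3, 192, kernel_size=16, stride=16)
img = torch.randn(1, 3, 512, 512)
feat = proj(img)                            # (1,192,32,32)
tokens = feat.flatten(2).transpose(1, 2)    # (1,1024,192)
print(feat.shape, tokens.shape)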

SegmenterMaskTransformerHead: decoder

class SegmenterMaskTransformerHead(BaseDecodeHead):
    def forward(self, inputs):  # inputs = (1,192,32,32)
        x = self._transform_inputs(inputs)  # input_transform is None in Segmenter, so x = (1,192,32,32)
        b, c, h, w = x.shape    # 1,192,32,32
        x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)   # x = (1,1024,192)

        x = self.dec_proj(x)    # .dec_proj = nn.Linear(in_channels, embed_dims) x = (1,1024,192)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)  # (1,2,192); self.cls_emb = nn.Parameter(torch.randn(1, self.num_classes, embed_dims))
        x = torch.cat((x, cls_emb), 1)  # x = (1,1026,192)
        for layer in self.layers:   # transformer encoder layers: attention lets every token (patches and class embeddings) gather global context
            x = layer(x)
        x = self.decoder_norm(x)

        patches = self.patch_proj(x[:, :-self.num_classes])  # self.patch_proj = nn.Linear (dimension unchanged); patches = (1,1024,192)
        cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])  #(1,2,192)

        patches = F.normalize(patches, dim=2, p=2)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)

        masks = patches @ cls_seg_feat.transpose(1, 2)  # (1,1024,2)
        masks = self.mask_norm(masks)
        masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)

        return masks    # (1,2,32,32)
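
The key step is the last few lines: after the joint transformer layers, the L2-normalized patch tokens are matched against the L2-normalized class embeddings by a dot product (i.e. cosine similarity), which directly yields one coarse mask per class. Reduced to a standalone toy example with random tensors (num_classes=2, matching the shapes above):

import torch
import torch.nn.functional as F

b, n, c, num_classes, h, w = 1, 1024, 192, 2, 32, 32
patches = F.normalize(torch.randn(b, n, c), dim=2, p=2)            # (1,1024,192)
cls_seg_feat = F.normalize(torch.randn(b, num_classes, c), dim=2, p=2)  # (1,2,192)

masks = patches @ cls_seg_feat.transpose(1, 2)                     # (1,1024,2)
masks = masks.permute(0, 2, 1).contiguous().view(b, num_classes, h, w)
print(masks.shape)                          # torch.Size([1, 2, 32, 32])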
