【clip源码阅读】VisionTransformer

lib/python3.8/site-packages/clip/model.py#L206

class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        # x: 输入原始图像,经过缩放,统一的大小为 224*224
        
        # 一幅输入224 x 224的图像,首先经过卷积处理得到16 x 16个patch,那么每一个patch的大小就是14 x 14
        # 将每一个patch的矩阵拉伸成为一个1维向量,从而获得了近似词向量堆叠的效果。上一步得道的14 x 14的patch就转换为长度为196的向量
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        
        # 每个patch拉伸为1*196
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        
        # 加上class embedding变为1*197的向量
        # class_embedding主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增加一个类别值,通常是加在向量的第一位,上一步得到的196维的向量加上class_embedding后变为197维。
        # 增加的class_embedding是一个可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输出来决定最后的输出类别;由于输入是16 x 16个patch,所以输出进行分类时是取 16 x 16个class_embedding进行分类。
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        
        # 加上1*197的position embedding
        # pos_embedding也是一组可以学习的参数,会被加入到经过处理的patch矩阵中
        # 它的加入类似于全链接网络和卷积的bias
        x = x + self.positional_embedding.to(x.dtype)
        
        # pre layer norm
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        
        # post layer norm
        x = self.ln_post(x[:, 0, :])
        # 由于增加的class_embedding是一个可以学习的参数,经过网络的不断训练
        # 最终以输出向量的第一个维度的输出来决定最后的输出类别
        # [bs, n_patch=257, dim=1024] -> [bs, dim=1024]      

        if self.proj is not None:
            x = x @ self.proj

        return x

你可能感兴趣的:(clip,clip,transformer)