Vision Transformer(1):ViT源码逐行阅读解析

 上图是Vision Transformer原文的模型结构展示,可以看到模型包含了几个核心模块:

 Vision Transformer:

        1. Embedding模块

        2. Transformer Encoder模块

                2.1 NormLayer ( × depth )

                        2.1.1 Multi-Head Attention层


                        2.1.2 MLP多层感知器

        3. MLP-Head 模块映射为类别


一、ViT & Embedding

假设训练数据维度为(64, 3, 224, 224),意味着有64张三通道的224*224的图像。



class ViT(nn.Module):
    """Vision Transformer image classifier.

    The image is cut into non-overlapping patches, each patch is linearly
    projected to ``dim`` features, a learnable class token and learnable
    position embeddings are added, the token sequence is run through the
    Transformer encoder, and a linear head maps the pooled feature to class
    logits.

    All constructor arguments are keyword-only.

    :param image_size: image size; an int (square image) or a
        ``(height, width)`` tuple
    :param patch_size: patch size; an int or a ``(height, width)`` tuple
    :param num_classes: number of output classes
    :param dim: embedding length of every patch token
    :param depth: number of stacked encoder layers
    :param heads: number of attention heads
    :param mlp_dim: hidden width of the MLP inside each encoder layer
    :param pool: ``'cls'`` to classify from the class token, ``'mean'`` to
        average over all tokens
    :param channels: number of image channels
    :param dim_head: feature width of a single attention head
    :param dropout: dropout ratio handed to the Transformer encoder
    :param emb_dropout: dropout ratio applied to the embedded tokens
    :return: ``forward`` returns class logits; in the running example used by
        the comments below (input ``(64, 3, 224, 224)``, ``patch_size=32``,
        ``dim=128``, ``num_classes=2``) the result is ``(64, 2)``
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        # nn.Module must be initialised before any parameter or sub-module
        # is assigned, otherwise the assignments below raise AttributeError.
        super().__init__()
        # image_size / patch_size are normalised to (height, width) by pair().
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        # The image must split into a whole number of patches.
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # Number of patches per image; 7 * 7 = 49 in the running example.
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # Flattened length of one patch; 3 * 32 * 32 = 3072 in the example.
        patch_dim = channels * patch_height * patch_width
        # 'cls' pools with the class token, 'mean' with mean pooling.
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # Embedding: (64, 3, 224, 224) -> Rearrange -> (64, 49, 3072)
        #            -> Linear -> (64, 49, 128) in the running example.
        self.to_patch_embedding = nn.Sequential(
            # Cut every image into h * w patches and flatten each patch.
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            # Project every flattened patch to `dim` features.
            nn.Linear(patch_dim, dim),
        )
        # Learnable position embedding, (1, num_patches + 1, dim); the extra
        # slot is for the class token that is prepended in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # Learnable class token, (1, 1, dim).
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # Dropout applied to the embedded token sequence.
        self.dropout = nn.Dropout(emb_dropout)
        # Encoder stack.
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        # Pooling mode; 'cls' by default.
        self.pool = pool
        self.to_latent = nn.Identity()
        # Classification head mapping the pooled feature to num_classes logits.
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Classify a batch of images and return logits per class."""
        # Step 1: patch embedding (split + linear projection) -> (64, 49, 128).
        x = self.to_patch_embedding(img)
        # Batch size and number of patches per image; b=64, n=49.
        b, n, _ = x.shape
        # (1, 1, 128) -> (64, 1, 128): one class token per image.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # Prepend the class token to the patch tokens -> (64, 50, 128).
        x = torch.cat((cls_tokens, x), dim=1)
        # Add the position embedding (1, 50, 128), broadcast over the batch.
        x += self.pos_embedding[:, :(n + 1)]
        # Dropout on the embedded sequence.
        x = self.dropout(x)

        x = self.transformer(x)
        # Pool: average over all tokens, or take the class token at index 0.
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        # Final mapping to class logits.
        return self.mlp_head(x)


class Transformer(nn.Module):
    """Encoder made of ``depth`` identical layers with residual connections.

    Every layer is a pair of pre-normalised blocks: multi-head attention
    followed by a feed-forward network.

    :param dim: feature width of every token
    :param depth: number of stacked layers
    :param heads: number of attention heads
    :param dim_head: feature width of a single attention head
    :param mlp_dim: hidden width of the feed-forward network
    :param dropout: dropout ratio passed to both sub-blocks
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        # Required before sub-modules can be registered on this module.
        super().__init__()
        # Chain `depth` encoder layers; each entry is an [attention, mlp] pair
        # stored as a ModuleList so its parameters are registered.
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        """Run ``x`` through every layer and return the result."""
        # Unpack the (norm + attention, norm + mlp) pair of each layer and
        # add the block input back to its output (residual connection).
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class PreNorm(nn.Module):
    """Apply layer normalisation to the input, then call the wrapped layer.

    :param dim: size of the last input dimension, which is the one normalised
    :param fn: the wrapped layer; in this file either the multi-head
        attention block or the MLP
    """

    def __init__(self, dim, fn):
        # Required before sub-modules can be registered on this module.
        super().__init__()
        # nn.LayerNorm(dim) normalises each token over its last dimension:
        # (x - mean) / sqrt(var + eps), followed by a learnable affine map.
        self.norm = nn.LayerNorm(dim)
        # Keep the wrapped layer for use in forward().
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and pass it, with any keyword arguments, to ``fn``."""
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP used inside every encoder layer.

    The input is expanded to ``hidden_dim`` features, passed through a GELU
    non-linearity, and projected back to ``dim`` features. Dropout follows
    both the activation and the output projection.

    :param dim: input and output feature width
    :param hidden_dim: width of the hidden layer
    :param dropout: dropout ratio used after the activation and the output
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        # Required before sub-modules can be registered on this module.
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            # Without a non-linearity the two Linear layers would collapse
            # into a single linear map.
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    :param dim: feature width of every input (and output) token
    :param heads: number of attention heads
    :param dim_head: feature width of a single head
    :param dropout: dropout ratio applied after the output projection
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        # Required before sub-modules can be registered on this module.
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is skipped only when a single head already
        # has the same width as the input.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        # 1 / sqrt(dim_head): keeps the query-key dot products from growing
        # with dim_head, which would push softmax towards one-hot outputs
        # (see "Scaled Dot-Product Attention" in "Attention Is All You Need").
        self.scale = dim_head ** -0.5
        # Softmax over the last dimension, so every query's weights over all
        # keys sum to 1.
        self.attend = nn.Softmax(dim = -1)
        # One linear layer produces q, k and v side by side; they are split
        # apart in forward().
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # Project the concatenated heads back to `dim` and regularise with
        # dropout; otherwise leave the output untouched.
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x):
        """Return the attended tokens; input and output are (b, n, dim)."""
        # (b, n, dim) -> (b, n, 3 * inner_dim), split into three chunks of
        # (b, n, inner_dim) for q, k and v, with inner_dim = heads * dim_head.
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # 'b n (h d) -> b h n d' separates the heads: each of q, k, v becomes
        # (b, heads, n, dim_head).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # Scaled query-key similarity, (b, heads, n, n).
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # Normalise over the key dimension to obtain attention weights.
        attn = self.attend(dots)
        # Weighted sum of the values, (b, heads, n, dim_head).
        out = torch.matmul(attn, v)
        # Merge the heads back together, (b, n, heads * dim_head).
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


