ViT (Vision Transformer) PyTorch implementation walkthrough, rwightman version

version_transformer

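  • Source code
  • Walkthrough
    • Stochastic path dropout (DropPath)
    • Serializing the input
    • Attention implementation
        • Forward pass
    • Multilayer perceptron:
    • Attention block:
      • Forward pass
    • Building ViT
        • Representation layer
        • Classification head
        • Weight initialization
        • _init_vit_weights:
      • Forward pass
      • The inner forward_features function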

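Source code

These are my notes on the Vision Transformer (ViT) model. For the overall architecture, see the diagram by 霹雳吧啦 below.
ViT source code, rwightman version
There is a lot of code, so this post only picks out the key points and does not paste everything. Follow along in the source yourself.
[Figure 1: ViT model architecture diagram]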

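Walkthrough

Stochastic path dropout (DropPath)

The code implements this with a DropPath class, which subclasses nn.Module. Judging by the name, it randomly drops an entire branch (path) rather than individual elements. The core of its forward function is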

keep_prob = 1 - drop_prob  # keep probability and drop probability; the two sum to 1
shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)

random_tensor.floor_()  # binarize

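keep_prob is the keep probability. torch.rand() draws uniformly from [0, 1), so keep_prob + rand lies in [keep_prob, 1 + keep_prob). After floor_() rounds down, an entry is 1 when rand >= drop_prob, which happens with probability keep_prob, and 0 otherwise. Because shape is (B, 1, ..., 1), there is one such 0/1 value per sample, so the whole branch is either kept or dropped for that sample. For example, with drop_prob = 0.2 the sum lies in [0.8, 1.8) and floors to 1 about 80% of the time.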
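As a minimal sketch of the mechanism described above, the function below reproduces the excerpt and adds two steps the excerpt does not show, both of which are assumptions here: an early return when drop_prob is 0 or the module is not training, and a division by keep_prob so that the expected value of the output matches the input. The name drop_path_sketch is hypothetical.

```python
import torch


def drop_path_sketch(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Drop the whole input of each sample with probability drop_prob."""
    # Assumption: do nothing at inference time or when the rate is zero.
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One random value per sample, broadcastable over all remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # 1 with probability keep_prob, otherwise 0
    # Assumption: rescale kept samples so the expectation stays equal to x.
    return x.div(keep_prob) * random_tensor


torch.manual_seed(0)
x = torch.ones(8, 3, 4)
out = drop_path_sketch(x, drop_prob=0.25, training=True)
# Each sample is either all zeros (dropped) or all 1 / 0.75 (kept and rescaled).
per_sample = out.flatten(1)
print(per_sample[:, 0])
```

Since the mask has one value per sample, a dropped sample contributes nothing from this branch and only the residual path x passes through the block.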

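Serializing the input

To train a Transformer on images, each image has to be turned into a sequence. A 224x224 image is cut by a 14x14 grid into 196 patches of 16x16 pixels, and each patch is mapped to a 768-dimensional vector. These vectors are then fed to the encoder. In practice this is done with a 16x16 convolution with stride 16, followed by flatten and transpose.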

class PatchEmbed(nn.Module):
    # excerpt from __init__
    # the image is turned into a sequence of num_patches tokens
    self.num_patches = self.grid_size[0] * self.grid_size[1]
    self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

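The projection is a single convolution. in_c=3 matches the image channels. embed_dim=768 is the length of each token vector, which equals the number of values in one patch: 16x16x3 = 768. kernel_size = stride = 16 is chosen so that outSize = (inSize - k + 2p)/s + 1 = (224 - 16 + 0)/16 + 1 = 14, which gives the 14x14 grid of patches.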
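The output-size formula above can be checked with a few lines of plain Python. The helper name conv_out_size is hypothetical and only restates the formula from the text.

```python
def conv_out_size(in_size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """outSize = (inSize - k + 2p) / s + 1, using integer (floor) division."""
    return (in_size - kernel + 2 * padding) // stride + 1


grid = conv_out_size(224, kernel=16, stride=16)  # patches per side
num_patches = grid * grid                        # sequence length without the class token
patch_values = 16 * 16 * 3                       # pixel values in one RGB patch

print(grid, num_patches, patch_values)  # 14 196 768
```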

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x

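The output is flattened to [B, C, HW] because ViT treats an image much like text, so merging height and width into a single sequence dimension simplifies the computation. B: batch size, C: channels = embed_dim = 768, HW = 196. After the transpose the shape is [B, 196, 768].

My reading is that the convolution stores each patch's pixels in 768 values, so the 196x768 matrix is the data of the 196 patches.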
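To make that reading concrete, here is a small sketch that traces the shapes through the projection and then checks that the strided convolution gives the same result as one linear map applied to each flattened patch. The use of F.unfold for the comparison is my own choice for illustration, not something the source code does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Same projection as in PatchEmbed: kernel and stride both equal the patch size.
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
x = torch.randn(2, 3, 224, 224)                   # [B, C, H, W]

feat = proj(x)                                    # [2, 768, 14, 14]
tokens = feat.flatten(2).transpose(1, 2)          # [2, 196, 768]

# Equivalent view: cut out each 3x16x16 patch, flatten it to 768 values,
# and apply the convolution weights as a 768 -> 768 linear layer.
patches = F.unfold(x, kernel_size=16, stride=16)  # [2, 768, 196]
patches = patches.transpose(1, 2)                 # [2, 196, 768]
weight = proj.weight.reshape(768, -1)             # [768, 3*16*16]
tokens_linear = patches @ weight.t() + proj.bias  # [2, 196, 768]

print(feat.shape, tokens.shape)
print(torch.allclose(tokens, tokens_linear, atol=1e-4))
```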

Attention implementation

class Attention(nn.Module):
    def __init__(self,
                 dim,   # dimension of each input token
                 num_heads=8,  # number of attention heads
                 qkv_bias=False,
                 qk_scale=None,  # optional override for the q @ k^T scaling factor
                 attn_drop_ratio=0.,  # dropout on the attention weights, after softmax
                 proj_drop_ratio=0.):  # dropout after the output projection
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  # integer (floor) division
        self.scale = qk_scale or head_dim ** -0.5  # defaults to head_dim ** -0.5 when qk_scale is not given
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # one linear layer with 3 * dim outputs, covering q, k and v
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)  # dim=768, dimension unchanged
        self.proj_drop = nn.Dropout(proj_drop_ratio)

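Producing the Q, K and V vectors would normally take three Linear layers. Here a single layer with dim*3 output nodes replaces the three and yields Q, K and V together. dim=768 is the embedding dimension of each token (the sequence length is 197).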

Forward pass

    def forward(self, x):  # x: [B, 197, 768]
        # [batch_size, num_patches + 1, total_embed_dim], where num_patches + 1 = 14x14 + 1 = 197
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # three tensors of shape [batch_size, num_heads, num_patches + 1, embed_dim_per_head], one each for q, k, v
        # the 3 * dim linear layer projects to q, k and v at once; after the permute, dim 0 has size 3 and indexes q, k, v
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1], so that (a x b) @ (b x a) = (a x a)
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)  # dim=-1 applies softmax over each row (dim=-2 would be per column)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # reshape concatenates the heads
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
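  • Input x: [B, 196+1, 768], a sequence of 197 tokens (196 image patches plus one extra token)
  • qkv is a linear layer with no activation function. Its output has 3 x embed_dim features, which hold the data for Q, K and V and are later split among them
  • reshape splits each token's features across the attention heads
  • permute reorders the dimensions
[batch_size, num_patches + 1 (14x14+1), total_embed_dim]
B, N, C = x.shape = [B, 196+1, 768]

The extra 1 here is the class token (cls_token), not a position encoding. It is added in class VisionTransformer(nn.Module).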

# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
# reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
# permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
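  • As noted for __init__ above, qkv() is a Linear(dim, 3 * dim) layer. It produces the q, k and v sequences, with shape [batch_size, num_patches + 1, 3 * total_embed_dim]
  • reshape() rearranges the dimensions: 3 -> the three sequences q, k, v; num_heads is the number of attention heads; C // self.num_heads is the slice of the embedding handled by each head.
  • permute(2, 0, 3, 1, 4) reorders the dimensions to [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
  • q, k, v = qkv[0], qkv[1], qkv[2] indexes along dim 0 to get the three sequences from the previous point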
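The shape changes listed above can be traced on a dummy tensor. The sizes below are the ViT-Base defaults used in this post; the variable q_direct is only there to show which output features of the joint linear layer end up in q.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, C, num_heads = 2, 197, 768, 12
x = torch.randn(B, N, C)
qkv_layer = nn.Linear(C, C * 3)

projected = qkv_layer(x)                                     # [2, 197, 2304]
qkv = projected.reshape(B, N, 3, num_heads, C // num_heads)  # [2, 197, 3, 12, 64]
qkv = qkv.permute(2, 0, 3, 1, 4)                             # [3, 2, 12, 197, 64]
q, k, v = qkv[0], qkv[1], qkv[2]                             # each [2, 12, 197, 64]

# The first C output features form q, the next C form k, the last C form v.
q_direct = projected[..., :C].reshape(B, N, num_heads, C // num_heads).transpose(1, 2)

print(q.shape, torch.equal(q, q_direct))
```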

Attention formula

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{T}}{\sqrt{d_k}}\right) V$$

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)  # dim=-1 applies softmax over each row (dim=-2 would be per column)
attn = self.attn_drop(attn)

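q, k and v each have shape [batch_size, num_heads, num_patches + 1, embed_dim_per_head].
As they stand, the last two dimensions of q and k are not compatible for matrix multiplication, so k is transposed first.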

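  • transpose(-2, -1) swaps the last two dimensions of k so that the product is (a x b) @ (b x a) = (a x a)
  • self.scale is 1/sqrt(d_k), where d_k = head_dim is the per-head dimension of the q and k vectors
  • softmax(dim=-1): dim=-1 applies softmax over each row; dim=-2 would apply it over each column
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # reshape concatenates the heads
        x = self.proj(x)
        x = self.proj_drop(x)
  • reshape(B, N, C) concatenates the outputs of all heads back into one vector per token
  • self.proj(x) is a fully connected layer that keeps the dimension unchanged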
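A minimal sketch of the computation in these bullets, using random q, k, v with the shapes from the walkthrough. Dropout and the final projection are left out so that only the formula itself is visible.

```python
import torch

torch.manual_seed(0)
B, H, N, D = 2, 12, 197, 64                      # batch, heads, tokens, per-head dim
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

scale = D ** -0.5                                # 1 / sqrt(d_k)
attn = (q @ k.transpose(-2, -1)) * scale         # [2, 12, 197, 197]
attn = attn.softmax(dim=-1)                      # every row is a distribution over tokens
out = attn @ v                                   # [2, 12, 197, 64]
out = out.transpose(1, 2).reshape(B, N, H * D)   # [2, 197, 768], heads concatenated

row_sums = attn.sum(dim=-1)                      # all entries close to 1
print(attn.shape, out.shape)
```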

Multilayer perceptron:

class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
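  • This is a plain MLP with two linear layers and a GELU activation
  • The first linear layer's output is 4x the input dimension when Block builds it with mlp_ratio=4 (if hidden_features is not passed, it falls back to in_features)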
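The excerpt above shows only __init__. The sketch below adds a forward method under the assumption that the layers are applied in the order fc1, activation, dropout, fc2, dropout. The class name MlpSketch is hypothetical, chosen to keep it apart from the class in the source.

```python
import torch
import torch.nn as nn


class MlpSketch(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Assumed order: expand, activate, dropout, project back, dropout.
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x


dim, mlp_ratio = 768, 4.0
mlp = MlpSketch(in_features=dim, hidden_features=int(dim * mlp_ratio))
y = mlp(torch.randn(2, 197, dim))
print(mlp.fc1.out_features, y.shape)  # hidden width 3072, output back at 768
```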

Next comes the Block module. ViT stacks 12 of these blocks, and each block is built from the classes above.

Attention block:

class Block(nn.Module):
    def __init__(self,
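                 dim,  # input dimension, 768
                 num_heads,  # number of attention heads
                 mlp_ratio=4.,  # the MLP hidden layer is 4x the input width
                 qkv_bias=False,
                 qk_scale=None,  # optional override for the attention scaling factor
                 drop_ratio=0.,  # dropout after the attention projection and inside the MLP
                 attn_drop_ratio=0.,  # dropout on the attention weights, after softmax
                 drop_path_ratio=0.,  # DropPath after attention and after the MLP, optional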
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

Dropout settings

# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()

This defines the block's drop_path attribute. When drop_path_ratio > 0 it is the custom DropPath; otherwise it is nn.Identity(), which passes its input through unchanged.

mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
  • mlp_ratio=4 matches the earlier remark that the MLP's first linear layer (the hidden layer) outputs 4x the input dimension.
  • This builds the MLP

Forward pass

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

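Both the attention module and the MLP module are followed by a drop-path step. Whether it does anything depends on drop_path_ratio from the previous step. In class VisionTransformer(nn.Module) this value increases step by step across the blocks, from 0 up to drop_path_ratio. Each branch is normalized first (norm1, norm2) and its output is added back to x as a residual.
Every block is Attn + MLP.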

Building ViT

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)  # initialize the parameter from a truncated normal distribution
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x


def __init__(self, img_size=224,
             patch_size=16,  # patch size; each patch is 16x16 pixels
             in_c=3,  # image channels
             num_classes=1000,  # number of classes in the dataset
             embed_dim=768,  # dimension of each token
             depth=12,  # number of (Attention + MLP) blocks
             num_heads=12,  # attention heads in each multi-head attention
             mlp_ratio=4.0,  # ratio of the MLP hidden width to embed_dim
             qkv_bias=True,
             qk_scale=None,  # optional override for the q @ k^T scaling factor
             representation_size=None, distilled=False, drop_ratio=0.,
             attn_drop_ratio=0.,  # attention dropout rate
             drop_path_ratio=0.,  # stochastic depth rate
             embed_layer=PatchEmbed,  # patch embedding implementation
             norm_layer=None,
             act_layer=None):
self.num_tokens = 2 if distilled else 1

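num_tokens is 2 for distilled models (DeiT-style, with an extra distillation token) and 1 for plain ViT, where it counts only the CLS token.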

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

This sets the normalization layer and the activation function. If none is passed in, the value after `or` is used.

self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)

This builds the patch embedding layer. In the forward pass it takes the image and converts it into a sequence.

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)
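  • cls_token and dist_token (the latter only for distilled models) are defined as learnable parameters, created as zero tensors.
  • pos_embed, the position embedding, is also created as a zero tensor, together with its dropout layer pos_drop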
dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule

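This gives a list of drop-path rates dpr, one per block, rising linearly from 0 to drop_path_ratio. Each rate is applied after the Attention and after the MLP in its block.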
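To see what the decay rule produces, the same line can be run on its own. The value drop_path_ratio = 0.1 is only an illustrative choice; the default in the constructor is 0.

```python
import torch

depth, drop_path_ratio = 12, 0.1   # 0.1 is an illustrative value, not from the source
dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]

# The first block never drops its branches; the last block uses the full rate.
for i, rate in enumerate(dpr):
    print(f"block {i:2d}: drop_path_ratio = {rate:.4f}")
```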

self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])

This stacks depth=12 blocks

Representation layer

        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

Classification head

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

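self.num_features defaults to the embedding dimension embed_dim (or representation_size when the representation layer is enabled). If num_classes=0, no classification head is created and nn.Identity() is used instead.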

Weight initialization

        nn.init.trunc_normal_(self.pos_embed, std=0.02)  # initialize the parameter from a truncated normal distribution
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

Truncated normal distribution: nn.init.trunc_normal_(self.pos_embed, std=0.02)

(function) trunc_normal_: (tensor: Tensor, mean: float = 0, std: float = 1, a: float = -2, b: float = 2) -> Tensor
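A short sketch of what this call does with the signature shown above. Note that a and b are absolute bounds, not multiples of std, so with std=0.02 the default limits of -2 and 2 are far outside the range the samples ever reach. The second call, with bounds of 2 standard deviations, is my own illustration and does not appear in the source.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
nn.init.trunc_normal_(pos_embed, std=0.02)   # mean=0, a=-2, b=2 by default

# With std=0.02 the default bounds sit 100 standard deviations from the mean,
# so this behaves like an ordinary normal draw in practice.
print(pos_embed.min().item(), pos_embed.max().item(), pos_embed.std().item())

# Illustrative variant: bounds that actually clip, at 2 standard deviations.
w = torch.empty(10000)
nn.init.trunc_normal_(w, std=0.02, a=-0.04, b=0.04)
print(w.min().item(), w.max().item())
```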

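self.apply(_init_vit_weights): in PyTorch, model.apply(fn) recursively applies the function fn to every submodule of the parent module, and to the parent module model itself. It is commonly used for weight initialization such as init_weights.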
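The traversal order can be seen on a toy model. The function record_and_init is a hypothetical stand-in for _init_vit_weights that also records which modules it was called on.

```python
import torch.nn as nn

visited = []


def record_and_init(m):
    """Hypothetical init function: note each module, zero the Linear biases."""
    visited.append(type(m).__name__)
    if isinstance(m, nn.Linear) and m.bias is not None:
        nn.init.zeros_(m.bias)


model = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 2))
model.apply(record_and_init)

print(visited)  # ['Linear', 'GELU', 'Linear', 'Sequential']
```

The children are visited first and the parent last, which is why an isinstance check inside the function is enough to reach every Linear, Conv2d and LayerNorm in the model.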

_init_vit_weights:

The input m is a module; apply calls this function once for each submodule of the model.

def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)  # truncated normal distribution
        if m.bias is not None:
            nn.init.zeros_(m.bias)  # zero the bias
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)  # set to one

Forward pass

 x = self.forward_features(x)

The inner forward_features function

def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

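This is where all the modules above are put to use

  • self.patch_embed(x) takes the image and turns it into patch tokens
  • expand(x.shape[0], -1, -1) expands the class token from [1, 1, 768] to [B, 1, 768]; -1 keeps that dimension at its original size
  • x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) prepends the class token (and, for distilled models, the distillation token) to the patch sequence. This is not the position encoding, which is added in the next step
        x = self.pos_drop(x + self.pos_embed)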
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]
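  • self.pos_drop(x + self.pos_embed) adds the position embedding to the sequence and then applies dropout.
  • self.blocks(x) runs the 12 stacked attention + MLP blocks.
  • self.norm(x) is the final layer normalization
  • return x[:, 0], x[:, 1] returns the entries at index 0 and index 1 along dim 1 (the token dimension), i.e. the class token and the distillation token. These are the features passed to the classification head(s). For plain ViT only x[:, 0] is returned, through pre_logits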
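The token handling in these steps can be sketched on dummy data for the non-distilled case. The random patches tensor stands in for the output of patch_embed, and the blocks and norm are skipped, so this only shows the shapes and where the class token sits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, num_patches, embed_dim = 4, 196, 768
patches = torch.randn(B, num_patches, embed_dim)   # stand-in for patch_embed output

cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

cls = cls_token.expand(B, -1, -1)      # [1, 1, 768] -> [4, 1, 768]
x = torch.cat((cls, patches), dim=1)   # [4, 197, 768], class token at index 0
x = x + pos_embed                      # pos_embed broadcasts over the batch

cls_out = x[:, 0]                      # [4, 768], the slot that feeds the head
print(x.shape, cls_out.shape)
```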
