ViT代码解读

读懂VIT

    • 整体思路
    • 切块操作
    • 位置编码添加
    • 多头注意力机制

整体思路

Vision Transformer 是将Transformer应用在计算机视觉中。Transformer是一个基于注意力的模型,他不依靠卷积神经网络,相比RNN,他可以进行并行运算;相比CNN,计算两者的关系,不会受到距离的远近而增加计算的长度;同时自注意力可以产生更具可解释性的模型。我们可以从模型中检查注意力分布。各个注意头(attention head)可以学会执行不同的任务。虽然Transformer有这么多的优点,但是将其应用到计算机视觉也存在一定的问题,由于在NLP任务中,句子的长度是并不是很长,对于图像,如果以像素为计算单元,一张图片的像素太多,计算量巨大,所以ViT提出将图像进行切块,进行操作。

切块操作

将图片切成相同大小的patch块,例如一张224x224的图片,切成16x16的块,则可以切成14x14块,对与每一个patch块,展平成1x768的序列,每一个序列前边加一个cls-token,最终将获得196个1x769的序列。

切块操作 代码片.

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
    super().__init__()
        img_size = to_2tuple(img_size)  
        patch_size = to_2tuple(patch_size)  
        num_patches = (img_size[1] 
                img_size[0] 
        self.patch_shape = (img_size[0] 
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
            def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  
        return x

位置编码添加

在NLP语言中,由于Transformer 不像RNN具有时序的关系,他是并行的输入,所以需要确定前后关系,于是提出了位置信息。在视觉中的Transformer中,也是需要添加位置信息,对与每一个patch块进行位置信息添加。

代码片.


多头注意力机制

将每一个patch分为num_heads份进行注意力的计算,单独的计算每一份的注意力权重,这里用的是自注意力机制,根据下方公式
计算注意力。ViT代码解读_第1张图片

代码片.

class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim
        if attn_head_dim is not None:
            
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads  
        self.scale = qk_scale or head_dim ** -0.5  

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)  
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))  
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))  
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)  
        self.proj = nn.Linear(all_head_dim, dim)  
        self.proj_drop = nn.Dropout(proj_drop)  #

    def forward(self, x):
        '''
        B,C,H,W-> B,N,C
        :param x:
        :return:
        '''
        B, N, C = x.shape
        qkv_bias = None 
        if self.q_bias is not None:
            
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
       
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # Batchsize patch数量 3(qkv) 头数(每个atten切分为几头) 宽高通道自适应  -->3(qkv) Batchsize 头数(每个atten切分为几头)patch数量 宽高通道自适应
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)  获得每个atten的值

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  

        attn = attn.softmax(dim=-1)  
        attn = self.attn_drop(attn)  

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)  
        x = self.proj(x) 
        x = self.proj_drop(x)  
        return x

你可能感兴趣的:(transformer,深度学习)