Vision Transformer 复现

Transformer在NLP上的运用示例
Vision Transformer 复现_第1张图片


Transformer运用在CV领域的难点
最主要的问题是将 2-D 的图片转换为 1-D的序列。若将 2-D 图片中的像素点直接拉直成 1-D 像素序列将会引起复杂度过高问题。
以图像分类任务为例:图像分类过程用到的输入图片尺寸为224x224,直接拉直后扔进transformer复杂度 O=224x224 = 50776 比现在能承受的 最大序列512还要大100倍。其余检测、分割等任务的输入尺寸已达到600x600甚至800x800或者更大。
想要将自注意力使用到CV任务中就需要对序列转换提出解决方案,将 2-D 图片转换得到的 1-D 序列长度缩小。

  • 更改输入源:将CNN得到的中间层特征图转为 1-D 序列来输入transformer
  • 孤立自注意力:针对输入图像某个局部的自注意力而非全图
  • 轴自注意力:将图像视为H*W的矩阵,high方向和wide方向分别作为 1-D 的输入

ViT中处理图像输入问题的方法:imgae-patch
patch大小控制着,每一个切分出来的图片块,展平之后向量维度是多少。
Vision Transformer 复现_第2张图片将每一个patch视为一个元素,相当于NLP处理的句子中的单词。由此应用在NLP上的transformer模型本身无需进行改动,便可以处理CV问题。

对于计算Patch Embeddings有多种方案,本案例使用 卷积操作来实现
上代码解决问题

class PatchEmbed(nn.Layer):
    '''
    2D image to patch embedding
    '''
    def __init__(self,
                img_size=224,
                patch_size=16,
                input_channel=3,
                embed_dim=16*16*3,
                norm_layer = None):
        super().__init__()
        
        img_size = (img_size,img_size)
        patch_size = (patch_size,patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0]//patch_size[0],img_size[1]//patch_size[1]) #整除获得image_patch个数应有多少
        self.num_patch = self.grid_size[0] * self.grid_size[0]

        self.proj = nn.Conv2D(in_channels=input_channel,
                            out_channels=embed_dim,
                            kernel_size=patch_size,
                            stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() #Identity为占位层

    def forward(self,x):
        B,C,H,W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1],\
            print('input size error')
        #flatten : [B,C,H,W] -> [B,C,HW] B = batch_size
        #transpose : [B,C,HW] -> [B,HW,C]
        x = self.proj(x).flatten(2)
        x = x.transpose([0,2,1])
        x = self.norm(x)
        print('img_num,token_num ,embedding_dimensions ',x.shape)
        return x

ViT的输入
(image_patch输入到模型还需要注意)
linear projection 之前 需要将原始切分的块进行flatten拉直。同时注意是有位置信息的,因为是图片,是有位置关系的。
下图中提到的linear projection过程是什么?(论文中使用的是768个神经元的全连接层)

Vision Transformer 复现_第3张图片
Vision Transformer 复现_第4张图片
Vision Transformer 复现_第5张图片

Vision Transformer 复现_第6张图片

使用sefl-attention时,计算两个向量间常用的方案Relevant关系的方案——Dot product

你可能感兴趣的:(学习笔记,transformer,深度学习)