Vis-transformer的模型结构及数据流图梳理

vision_transformer顾名思义,是将自然语言处理中的Transformer的思想应用在图像分类中,将一张图像切成不同的patch之后作为sequence,使用Transformer来实现图像分类。

本文主要是梳理vis-transformer的模型结构,以及图像数据在各种算子operator下的shape变化。方便理解和进一步开发。

原论文的模型结构示意图:

Vis-transformer的模型结构及数据流图梳理_第1张图片

 

第一步:

将(1, 3, 224, 224)的图像,通过Conv2d, kernel=(16,16),stride=(16,16),变成(1,196,768)的数组形状,其中196可以理解为将原图像的H,W分别安装16切成的14*14个patches,768是定义的输出通道数。

第二步:

然后将类别的cls_token与卷积后的feature数据进行concat, 再add上pos的编码信息。

第三步:

transformer的结构,需要循环执行12次。每个transformer block包括attention和MLP编码模块。

attention模块

数据先进行layerNorm操作,然后通过一个Liner将输入x的channel维度扩3倍,再通过reshape和transpose操作分别把这三个维度给query, key, value这三个变量

query和key进行点乘,再乘以scale, 进行softmax,

然后softmax的结果和value进行点乘,进行reshape,transpose回数据进来时候的尺寸

最后通过一个Linear算子,将数据和第一步进来的数据进行点加,开始传递给MLP模块

MLP模块

进来后数据首先进行一个layerNorm操作,然后通过两个FC(全连接层,也是linear算子)将数据channel进行先扩大在缩回的操作,中间使用GELU激活函数。

最后一个FC层后的数据和MLP刚进来的数据进行点加,将数据传进下一步骤。

第四步:

最后的过程,先提取第一步插入的第0个通道的cls_token,即为类别的编码信息

通过Linear将类别信息定位到具体的class_num上,再通过经典的softmax, argmax得到图像的分类结果。

Excalidraw | Hand-drawn look & feel • Collaborative • Secure

Vis-transformer的模型结构及数据流图梳理_第2张图片

你可能感兴趣的:(深度学习,transformer,深度学习,人工智能,注意力模型)