The figure above shows the model architecture from the original Vision Transformer paper; the model is built from a few core modules:
Vision Transformer:
1. Embedding module
2. Transformer Encoder module
2.1 NormLayer ( × depth )
2.1.1 Multi-Head Attention layer
(a detailed walkthrough of the attention mechanism follows below)
2.1.2 MLP (multi-layer perceptron)
3. MLP-Head module that maps the final feature to class scores
Bottom-up tinkering is indispensable when exploring the unknown, but after that exploration a top-down walkthrough turns out to explain the overall logic more clearly.
Assume the training data has shape (64, 3, 224, 224), i.e. 64 three-channel 224×224 images.
Setting dim=128 means each patch is encoded as a vector of length 128.
The PreNorm, Attention, FeedForward and Transformer components that appear inside ViT are explained further below.
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class ViT(nn.Module):
    '''
    :param
        *: input data
        image_size: side length of the (square) input image
        patch_size: side length of each patch
        num_classes: number of target classes
        dim: length of the embedding vector for each patch
        depth: depth of the encoder, i.e. the number of stacked encoder blocks
        heads: number of heads in multi-head attention
        mlp_dim: hidden dimension of the feed-forward MLP
        pool: use the cls token or mean pooling
        channels: number of image channels
        dim_head: dimension of a single attention head
        dropout: dropout ratio used inside the Attention and FeedForward layers
        emb_dropout: dropout ratio applied to the embeddings
    :return classification logits, e.g. (64, 2) for num_classes = 2
    '''
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
        # image_size is the height/width of each input image; pair() turns a single int into a (h, w) tuple
        # patch_size is the height/width of each patch
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
        # make sure the image can be split into an integer number of patches
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # number of patches each image is split into
        # with input (64, 3, 224, 224) and patch_size 32: num_patches = 7 * 7 = 49
num_patches = (image_height // patch_height) * (image_width // patch_width)
        # flattened size of one patch: patch_dim = 3 * 32 * 32 = 3072
patch_dim = channels * patch_height * patch_width
        # 'cls' pools via the classification token, 'mean' via mean pooling
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # embedding step: with input (64, 3, 224, 224), the Rearrange layer produces (64, 7*7=49, 32*32*3=3072)
self.to_patch_embedding = nn.Sequential(
            # split each of the b images into h*w three-channel patches and flatten each patch
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            # the linear projection then maps the patches to (64, 49, 128)
nn.Linear(patch_dim, dim),
)
        # every token (patch) gets a learnable position vector of length dim
        # position embedding shape (1, 50, 128): num_patches = 49 plus one extra position for the cls token
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # learnable CLS (classification) token, shape (1, 1, 128)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # dropout applied to the embeddings
self.dropout = nn.Dropout(emb_dropout)
        # build the Transformer encoder
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        # pool defaults to 'cls', i.e. classify from the cls token
self.pool = pool
self.to_latent = nn.Identity()
        # the MLP head maps the final feature to num_classes (here 2) logits
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
        # step 1: patch embedding, i.e. split the image into patches and project them linearly: x -> (64, 49, 128)
x = self.to_patch_embedding(img)
        # batch size and number of patches per image: b = 64, n = 49
b, n, _ = x.shape
        # (1, 1, 128) -> (64, 1, 128): one cls token per image
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # prepend the cls token to the patch tokens -> (64, 50, 128)
x = torch.cat((cls_tokens, x), dim=1)
        # add the position embedding (1, 50, 128) to x (64, 50, 128)
x += self.pos_embedding[:, :(n + 1)]
        # dropout to reduce overfitting
x = self.dropout(x)
x = self.transformer(x)
        # either mean-pool over all tokens or take the cls token, depending on self.pool
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
        # final classification head
return self.mlp_head(x)
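To make the embedding shapes concrete, here is a minimal sketch of the patch-splitting step on its own, reusing the torch/einops imports from the top of the listing (the tensor names demo_imgs/demo_patches are just illustrative, and the Linear weights are random, so only the shapes matter):
demo_imgs = torch.randn(64, 3, 224, 224)
to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 32, p2 = 32)
demo_patches = to_patches(demo_imgs)
print(demo_patches.shape)                        # torch.Size([64, 49, 3072])
print(nn.Linear(3072, 128)(demo_patches).shape)  # torch.Size([64, 49, 128])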
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
        # stack depth encoder blocks; the residual connections are added in forward()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
        # each block is a ModuleList of (PreNorm-Attention, PreNorm-FeedForward); apply both with residual connections
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
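A quick shape check of the encoder, assuming the PreNorm, Attention and FeedForward classes defined below (depth = 2, heads = 8, dim_head = 64 and mlp_dim = 256 are illustrative values, not prescribed by the text):
encoder = Transformer(dim = 128, depth = 2, heads = 8, dim_head = 64, mlp_dim = 256)
tokens = torch.randn(64, 50, 128)   # 64 images, 49 patches + 1 cls token, dim = 128
print(encoder(tokens).shape)        # torch.Size([64, 50, 128]): every block preserves the token shape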
class PreNorm(nn.Module):
    '''
    :param dim: input dimension
           fn: the wrapped layer, either the Multi-Head Attention or the FeedForward MLP
    '''
def __init__(self, dim, fn):
super().__init__()
        # LayerNorm over the last dimension: (x - mean) / sqrt(var + eps), followed by a learned scale and shift
        # store the normalization layer and the wrapped layer
self.norm = nn.LayerNorm(dim)
self.fn = fn
    # the forward pass normalizes the input and then feeds it to the wrapped layer
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
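A small check of the LayerNorm formula described in the comment above: nn.LayerNorm(dim) normalizes over the last dimension using the biased variance and an eps term, then applies its learned weight and bias (this check is a sketch, not part of the model):
x = torch.randn(64, 50, 128)
ln = nn.LayerNorm(128)
mean = x.mean(dim = -1, keepdim = True)
var = x.var(dim = -1, unbiased = False, keepdim = True)
manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias
print(torch.allclose(ln(x), manual, atol = 1e-6))   # True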
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
        # two-layer MLP: dim -> hidden_dim -> dim, with GELU activation and dropout
        self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
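The feed-forward block acts on each token independently and leaves the token dimension unchanged (hidden_dim = 256 below is an illustrative value):
ff = FeedForward(dim = 128, hidden_dim = 256, dropout = 0.1)
print(ff(torch.randn(64, 50, 128)).shape)   # torch.Size([64, 50, 128])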
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = heads * dim_head
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
        # scale = 1/sqrt(dim_head); it keeps the dot products at unit variance so softmax does not saturate and zero out most outputs
        # see Scaled Dot-Product Attention in "Attention Is All You Need" for why large dot products must be suppressed
self.scale = dim_head ** -0.5
        # softmax over the last dimension (the key dimension), so each query's attention weights sum to 1
self.attend = nn.Softmax(dim = -1)
        # a single linear layer produces q, k and v stacked together; they are split apart later
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # if a projection is needed (multiple heads or dim_head != dim), project back to dim and apply dropout; otherwise pass through unchanged
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
        # split the stacked projection into the q, k and v matrices
        # to_qkv outputs inner_dim * 3 features, where inner_dim = heads * dim_head
qkv = self.to_qkv(x).chunk(3, dim = -1)
        # qkv is a tuple of three tensors, each of shape (b, n, inner_dim) = (b, n, heads * dim_head)
        # 'b n (h d) -> b h n d' separates the heads, giving one set of q, k, v per head
        # after rearrange, q, k and v each have shape (b, heads, n, dim_head), e.g. (64, 8, 50, 64)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # query · key gives the attention scores over the values; scaling prevents softmax from zeroing out most entries
        # dots has shape (b, heads, n, n)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # softmax over the last dimension turns the scores into attention weights
attn = self.attend(dots)
        # weighted sum of the values
        # out has shape (b, heads, n, dim_head)
out = torch.matmul(attn, v)
        # merge the heads back: (b, heads, n, dim_head) -> (b, n, heads * dim_head)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
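To illustrate the corrected shape comments, here is a sketch of the scaled dot-product step in isolation, followed by a shape check of the Attention module itself (the q/k tensors below are random stand-ins with heads = 8, n = 50, dim_head = 64):
q = torch.randn(64, 8, 50, 64)
k = torch.randn(64, 8, 50, 64)
dots = torch.matmul(q, k.transpose(-1, -2)) * (64 ** -0.5)
weights = dots.softmax(dim = -1)
print(weights.shape)                  # torch.Size([64, 8, 50, 50])
print(weights.sum(dim = -1)[0, 0, 0]) # ~1.0: each query's attention weights sum to 1
attn_layer = Attention(dim = 128, heads = 8, dim_head = 64)
print(attn_layer(torch.randn(64, 50, 128)).shape)   # torch.Size([64, 50, 128])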
def pair(t):
    # pair(224) -> (224, 224); tuples are returned unchanged
    return t if isinstance(t, tuple) else (t, t)
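Finally, a minimal end-to-end sketch tying the pieces together, using the shapes from the running example (depth = 6, heads = 8 and mlp_dim = 256 are illustrative hyperparameters, not taken from the text):
model = ViT(
    image_size = 224,
    patch_size = 32,
    num_classes = 2,
    dim = 128,
    depth = 6,
    heads = 8,
    mlp_dim = 256,
    dropout = 0.1,
    emb_dropout = 0.1
)
img = torch.randn(64, 3, 224, 224)   # the (64, 3, 224, 224) batch from the running example
logits = model(img)
print(logits.shape)                  # torch.Size([64, 2])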