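These are notes on the Vision Transformer (ViT) model. For the overall architecture, see the diagram by 霹雳吧啦 below.
ViT source code, rwightman version
There is a lot of code, so this post only analyses the key points rather than pasting everything; follow along with the source yourself.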
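The code uses a DropPath class, a subclass of nn.Module, to implement random dropping. Judging by the name, what gets dropped at random is a whole branch (stochastic depth). The core of its forward function is: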
keep_prob = 1 - drop_prob  # keep (survival) probability = 1 - drop probability
shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()  # binarize
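keep_prob is the survival probability. torch.rand() draws each element uniformly from [0, 1), so keep_prob + torch.rand(...) lies in [keep_prob, 1 + keep_prob). After floor_() rounds down, an element becomes 1 when the random draw is at least drop_prob (which happens with probability keep_prob) and 0 otherwise. Because shape is (B, 1, ..., 1), there is one such 0/1 value per sample, so each sample's branch is kept with probability keep_prob and dropped with probability drop_prob. For example, when drop_prob is below 0.5 the added offset is above 0.5, so most samples are kept. The randomness comes entirely from torch.rand().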
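To check that floor(keep_prob + U) really behaves like a keep/drop mask, here is a minimal pure-Python sketch. It does not use PyTorch: `random.random()` stands in for `torch.rand`, and `sample_mask` is a hypothetical helper name introduced only for this illustration.

```python
import math
import random


def sample_mask(keep_prob, n, seed=0):
    """Draw n values of floor(keep_prob + U), with U uniform on [0, 1)."""
    rng = random.Random(seed)
    return [float(math.floor(keep_prob + rng.random())) for _ in range(n)]


mask = sample_mask(keep_prob=0.8, n=10_000)
kept_fraction = sum(mask) / len(mask)

print(set(mask))       # only 0.0 and 1.0 can occur
print(kept_fraction)   # should be close to keep_prob
```

Every entry is either 0 or 1, and the fraction of ones approaches keep_prob as n grows.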
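To train a Transformer on image data, the image has to be turned into a sequence. A 224x224 image is divided by a 14x14 grid into 196 patches of 16x16 pixels, each patch is embedded as a 768-dimensional vector, and the resulting sequence is fed to the encoder. In practice this is done by applying a convolution with a 16x16 kernel and stride 16 to the 224x224 image, then flattening and transposing the result.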
class PatchEmbed(nn.Module):
    # excerpt from __init__: the image is turned into a sequence of num_patches tokens
    self.num_patches = self.grid_size[0] * self.grid_size[1]
    self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
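The patches are projected by a single convolution. in_c=3 is the number of image channels. embed_dim=768 is the dimension of each token in the sequence (the patch size itself is 16x16). kernel_size = stride = 16 is chosen so that, by outSize = (inSize - k + 2p)/s + 1 = (224 - 16 + 0)/16 + 1 = 14, the output is a 14x14 grid of patches.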
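The output-size formula above can be evaluated directly. The following is a small sketch in plain Python; `conv_out_size` is a hypothetical helper written for this note, not a function from the ViT source.

```python
def conv_out_size(in_size, kernel, stride, padding=0):
    """Output length of a convolution along one spatial axis."""
    return (in_size - kernel + 2 * padding) // stride + 1


grid = conv_out_size(in_size=224, kernel=16, stride=16)
num_patches = grid * grid
print(grid, num_patches)  # 14 196
```

With kernel_size equal to stride, the windows do not overlap, so each output position corresponds to exactly one 16x16 patch.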
    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x
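The output is flattened to [B, C, HW] and then transposed to [B, HW, C] because ViT treats an image much like text: merging height and width into one dimension gives a sequence of tokens, which is convenient for the computation. Here B is the batch size, C is the channel count (= embed_dim = 768), and HW = 14 x 14 = 196.
My reading is that the convolution stores the pixels of each patch in 768 elements, so the 196x768 result is best seen as the data of 196 patches.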
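The shape changes in the forward pass can be traced without running the model. This is only a bookkeeping sketch with plain tuples; `patch_embed_shapes` is a hypothetical helper, and the default arguments are the ViT-Base/16 values quoted in this post.

```python
def patch_embed_shapes(batch, in_c=3, img=224, patch=16, embed_dim=768):
    """Trace tensor shapes through proj -> flatten(2) -> transpose(1, 2)."""
    grid = img // patch
    after_proj = (batch, embed_dim, grid, grid)        # Conv2d output
    after_flatten = (batch, embed_dim, grid * grid)    # H and W merged
    after_transpose = (batch, grid * grid, embed_dim)  # tokens first, features last
    values_per_patch = patch * patch * in_c            # raw pixel values in one patch
    return after_proj, after_flatten, after_transpose, values_per_patch


for item in patch_embed_shapes(batch=2):
    print(item)
```

For this configuration a patch holds 16 x 16 x 3 = 768 raw values, which happens to equal embed_dim, consistent with the intuition that each 768-element token carries the data of one patch.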
class Attention(nn.Module):
    def __init__(self,
                 dim,  # dimension of each input token
                 num_heads=8,  # number of attention heads
                 qkv_bias=False,
                 qk_scale=None,  # optional override for the q @ k scale factor
                 attn_drop_ratio=0.,  # dropout on the attention weights (after softmax)
                 proj_drop_ratio=0.):  # dropout after the output projection
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  # integer (floor) division
        self.scale = qk_scale or head_dim ** -0.5  # head_dim ** -0.5 when qk_scale is not given
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # one Linear standing in for three (q, k, v) Linears of width dim
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)  # dim=768, dimension unchanged
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):  # x: [B, 197, 768]
        # [batch_size, num_patches + 1(14x14+1), total_embed_dim]
        B, N, C = x.shape
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q, k, v each: [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # The Linear of width 3 * dim produces q, k and v in one pass; after the permute, dim 0 has size 3 and indexes q, k, v.
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]  (so that (a x b) @ (b x a) = (a x a))
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)  # dim=-1: softmax over each row (dim=-2 would be over columns)
        attn = self.attn_drop(attn)
        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # reshape concatenates the heads
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
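Producing the Q, K and V vectors would normally take three Linear layers. Here a single Linear with dim * 3 output features replaces the three layers and yields Q, K and V together. dim=768 is the embedding dimension of each token (not the sequence length, which is 197).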
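The claim that one Linear with dim * 3 outputs can replace three separate Linear layers can be checked on a toy example. The sketch below uses plain Python lists instead of PyTorch, ignores the bias, and uses dim = 4 rather than 768; `linear` and the weight names are hypothetical and chosen only for illustration.

```python
import random


def linear(weight, x):
    """Bias-free linear layer: weight is [out, in], x has length in."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


rng = random.Random(0)
dim = 4  # toy size; the model in this post uses dim = 768
x = [rng.uniform(-1, 1) for _ in range(dim)]

# Three separate projection matrices, each [dim, dim].
w_q, w_k, w_v = (
    [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(dim)]
    for _ in range(3)
)

# One fused matrix [3 * dim, dim]: the three matrices stacked row-wise.
w_qkv = w_q + w_k + w_v
fused = linear(w_qkv, x)

# Splitting the fused output recovers the three separate projections.
q, k, v = fused[:dim], fused[dim:2 * dim], fused[2 * dim:]
print(q == linear(w_q, x), k == linear(w_k, x), v == linear(w_v, x))
```

The fused layer is just the three weight matrices stacked, so one matrix multiplication does the work of three.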
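- The input x has shape [B, 196 + 1, 768], i.e. [batch_size, num_patches + 1 (14x14 + 1), total_embed_dim], unpacked as B, N, C = x.shape. It is a sequence of 197 tokens. The extra 1 is the class token (not a position encoding), which is prepended in class VisionTransformer(nn.Module).
- qkv is a linear layer with no activation function. Its output has 3 x embed_dim features, which hold the Q, K and V data and are later split into Q, K and V.
- reshape splits each token's features among the attention heads.
- permute rearranges the order of the dimensions.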
# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
# reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
# permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
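- As noted for __init__ above, qkv() is a Linear layer mapping dim to 3 * dim. It produces the q, k, v data with shape [batch_size, num_patches + 1, 3 * total_embed_dim].
- reshape() changes the dimensions: the 3 indexes q, k and v, num_heads is the number of attention heads, and C // self.num_heads is the number of features each head handles.
- permute(2, 0, 3, 1, 4) reorders the dimensions to [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head].
- q, k, v = qkv[0], qkv[1], qkv[2] indexes along the first dimension to obtain the three tensors q, k and v mentioned above.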
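The reshape and permute steps are easiest to follow by tracking shapes only. The sketch below does this with tuples in plain Python; `qkv_shapes` is a hypothetical helper, and the example values (N = 197, C = 768, 12 heads) are the ViT-Base settings used in this post.

```python
def qkv_shapes(B, N, C, num_heads):
    """Trace shapes through qkv() -> reshape -> permute(2, 0, 3, 1, 4)."""
    head_dim = C // num_heads
    after_linear = (B, N, 3 * C)
    after_reshape = (B, N, 3, num_heads, head_dim)
    order = (2, 0, 3, 1, 4)
    after_permute = tuple(after_reshape[i] for i in order)
    per_tensor = after_permute[1:]  # shape of each of q, k, v after indexing dim 0
    return after_linear, after_reshape, after_permute, per_tensor


for shape in qkv_shapes(B=1, N=197, C=768, num_heads=12):
    print(shape)
```

Each head therefore works on 768 // 12 = 64 features per token.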
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)  # dim=-1: softmax over each row (dim=-2 would be over columns)
attn = self.attn_drop(attn)
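q, k and v each have shape [batch_size, num_heads, num_patches + 1, embed_dim_per_head]. Multiplying q by k directly is not possible, because their last two dimensions do not satisfy the rule for matrix multiplication.
- transpose(-2, -1) swaps the last two dimensions of k so that the product works out: (a x b) @ (b x a) = (a x a).
- self.scale is 1/sqrt(d_k), where d_k = head_dim is the per-head feature dimension (not the sequence length).
- softmax(dim=-1) normalises each row; dim=-2 would normalise each column.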
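The three steps (q @ k^T, scaling by head_dim ** -0.5, row-wise softmax) can be reproduced for a single head on tiny inputs. This is a minimal sketch in plain Python, not the batched PyTorch code; `softmax`, `attention_weights` and the sample vectors are made up for illustration.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(q, k):
    """q, k: lists of N vectors of length head_dim. Returns an N x N matrix."""
    head_dim = len(q[0])
    scale = head_dim ** -0.5
    scores = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]
    return [softmax(row) for row in scores]  # softmax over the last dimension


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
attn = attention_weights(q, k)
print([round(sum(row), 6) for row in attn])  # each row sums to 1
```

Each row of the result is a probability distribution over the tokens, which is what softmax(dim=-1) produces in the original code.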
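x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # reshape concatenates the heads
x = self.proj(x)
x = self.proj_drop(x)
- reshape(B, N, C) concatenates the outputs of all heads back into one vector per token.
- self.proj(x) is a fully connected layer that leaves the dimension unchanged.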
class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
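- This is a simple MLP with two linear layers and a GELU activation.
- In ViT the first linear layer's output dimension is 4 times its input. That factor comes from mlp_ratio=4 in Block; when hidden_features is not given, it defaults to in_features.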
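For reference, the exact GELU (the default form of nn.GELU) is x * Phi(x), where Phi is the standard normal CDF. The sketch below evaluates it with the standard library and also computes the hidden width that the 4x ratio gives for dim = 768. `gelu` and `mlp_hidden_dim` are hypothetical helper names for this note.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def mlp_hidden_dim(dim, mlp_ratio=4.0):
    """Hidden width of the MLP, computed as in Block: int(dim * mlp_ratio)."""
    return int(dim * mlp_ratio)


print(gelu(0.0), gelu(1.0), gelu(-1.0))
print(mlp_hidden_dim(768))  # 3072
```

Unlike ReLU, GELU is smooth and lets small negative inputs through with a small negative output.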
Next is the Block module. ViT stacks 12 of these blocks, and each block is built from the classes above.
class Block(nn.Module):
    def __init__(self,
                 dim,  # input dimension, 768
                 num_heads,  # number of attention heads
                 mlp_ratio=4.,  # hidden width of the MLP is 4x the input
                 qkv_bias=False,
                 qk_scale=None,  # optional override for the attention scale factor
                 drop_ratio=0.,  # dropout after the attention projection and in the MLP
                 attn_drop_ratio=0.,  # dropout after the softmax
                 drop_path_ratio=0.,  # DropPath after attention and after the MLP, optional
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
Drop settings
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
This defines the block's drop_path attribute. When drop_path_ratio > 0, self.drop_path is the custom DropPath; otherwise it is nn.Identity(), which passes its input through unchanged.
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
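- mlp_ratio=4 corresponds to the point made above: the output of the MLP's first linear layer (the hidden layer) is 4 times the input dimension.
- self.mlp instantiates the MLP.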
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
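The attention module and the MLP module are each followed by a drop operation (drop_path). Whether it does anything depends on the value of drop_path_ratio discussed in the previous step. In class VisionTransformer(nn.Module), this value increases step by step over a range, from 0 for the first block up to drop_path_ratio for the last.
Each Block consists of Attn + MLP.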
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)  # initialise parameters from a truncated normal distribution
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x
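    def __init__(self, img_size=224,
                 patch_size=16,  # patch size: each patch is 16x16 pixels
                 in_c=3,  # image channels
                 num_classes=1000,  # number of classes in the dataset
                 embed_dim=768,  # embedding dimension of each token
                 depth=12,  # number of (Attention + MLP) blocks
                 num_heads=12,  # number of heads in each multi-head attention
                 mlp_ratio=4.0,  # ratio of MLP hidden dim to embed_dim
                 qkv_bias=True,
                 qk_scale=None,  # override for the default qk scale of head_dim ** -0.5
                 representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0.,  # attention dropout rate
                 drop_path_ratio=0.,  # stochastic depth rate
                 embed_layer=PatchEmbed,  # patch embedding implementation
                 norm_layer=None,
                 act_layer=None):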
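self.num_tokens = 2 if distilled else 1
This accommodates other model variants (a distilled model has two such tokens). For ViT, self.num_tokens = 1, i.e. the single class (CLS) token.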
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
These lines define the normalization layer and the activation function. If none is passed in, the expression after or is used.
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
This builds the patch-embedding layer, which takes the input image and converts it into a token sequence.
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_ratio)
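- Following the model design, cls_token (and dist_token for distilled models) is defined as a zero tensor wrapped in nn.Parameter; its values are set by the weight initialisation further down.
- pos_embed is likewise defined as a zero tensor for the position encoding, together with the dropout layer pos_drop.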
dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
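This produces a list of drop rates dpr, one per block, increasing linearly from 0 to drop_path_ratio. Each rate is applied after the Attention and after the MLP in the corresponding block.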
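To see what the per-block drop rates look like, here is a plain-Python stand-in for torch.linspace. The helper `linspace` is written for this note, and drop_path_ratio = 0.1 is an assumed example value (the default in the code is 0., which makes every rate zero).

```python
def linspace(start, stop, steps):
    """Evenly spaced values from start to stop inclusive, like torch.linspace."""
    if steps == 1:
        return [start]
    step = (stop - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


depth = 12
drop_path_ratio = 0.1  # assumed example value, not the default
dpr = linspace(0.0, drop_path_ratio, depth)
print([round(p, 4) for p in dpr])
```

Early blocks are almost never dropped, while the last block is dropped with probability drop_path_ratio.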
self.blocks = nn.Sequential(*[
Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)
])
This stacks depth=12 blocks.
if representation_size and not distilled:
self.has_logits = True
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
("fc", nn.Linear(embed_dim, representation_size)),
("act", nn.Tanh())
]))
else:
self.has_logits = False
self.pre_logits = nn.Identity()
# Classifier head(s)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
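self.num_features defaults to the embedding dimension embed_dim (or representation_size when the pre-logits layer is enabled). If the number of classes num_classes is 0, no classification head is created and nn.Identity() is used instead.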
nn.init.trunc_normal_(self.pos_embed, std=0.02)  # initialise parameters from a truncated normal distribution
if self.dist_token is not None:
nn.init.trunc_normal_(self.dist_token, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
self.apply(_init_vit_weights)
Truncated normal distribution
nn.init.trunc_normal_(self.pos_embed, std=0.02)
Signature: trunc_normal_(tensor: Tensor, mean: float = 0, std: float = 1, a: float = -2, b: float = 2) -> Tensor
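A truncated normal draws from N(mean, std^2) but only accepts values inside [a, b]. The sketch below illustrates the idea with simple rejection sampling in plain Python; it is an illustration of the distribution, not a description of how PyTorch computes it internally, and `trunc_normal` is a hypothetical helper. Note that a and b are absolute bounds rather than multiples of std, so with std=0.02 and the default bounds of -2 and 2 the truncation almost never triggers.

```python
import random


def trunc_normal(mean=0.0, std=1.0, a=-2.0, b=2.0, rng=random):
    """Draw from N(mean, std^2), redrawing until the value lies in [a, b]."""
    while True:
        value = rng.gauss(mean, std)
        if a <= value <= b:
            return value


rng = random.Random(0)
narrow = [trunc_normal(std=0.02, rng=rng) for _ in range(1000)]
clipped = [trunc_normal(std=1.0, a=-0.5, b=0.5, rng=rng) for _ in range(1000)]
print(min(narrow), max(narrow))
print(min(clipped), max(clipped))
```

The first sample stays far inside the bounds because of the small std; the second shows the bounds actually cutting off the tails.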
self.apply(_init_vit_weights)
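In PyTorch, model.apply(fn) recursively applies the function fn to every submodule of the parent module, and also to the parent model itself. It is commonly used for weight initialisation.
The argument m is a module.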
def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)  # truncated normal distribution
        if m.bias is not None:
            nn.init.zeros_(m.bias)  # bias set to 0
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)  # weight set to 1
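The recursion performed by apply can be pictured with a toy class. `ToyModule` below is a hypothetical stand-in for nn.Module written only for this illustration: it visits the children first and then the module itself, which is the order nn.Module.apply uses.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module, only to illustrate apply()."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def apply(self, fn):
        for child in self.children:
            child.apply(fn)  # recurse into submodules first
        fn(self)             # then call fn on this module
        return self


model = ToyModule("vit", [
    ToyModule("patch_embed", [ToyModule("proj")]),
    ToyModule("head"),
])
visited = []
model.apply(lambda m: visited.append(m.name))
print(visited)  # ['proj', 'patch_embed', 'head', 'vit']
```

This is why a single self.apply(_init_vit_weights) call reaches every Linear, Conv2d and LayerNorm in the model.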
x = self.forward_features(x)
def forward_features(self, x):
# [B, C, H, W] -> [B, num_patches, embed_dim]
x = self.patch_embed(x) # [B, 196, 768]
# [1, 1, 768] -> [B, 1, 768]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
if self.dist_token is None:
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
else:
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
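This is where all the modules above are put to use.
- self.patch_embed(x) takes in the image and converts it into patch tokens.
- expand(x.shape[0], -1, -1) changes the batch dimension of the token from 1 to B; -1 means that dimension keeps its original size.
- torch.cat((cls_token, ...), dim=1) prepends the class token (and the distillation token, if present) to the patch tokens. The position encoding is added in the next step.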
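The effect of expand and of the concatenation on the shapes can be sketched with tuples. The two helpers below are hypothetical and only mimic the shape rules (they do not touch real tensors); a batch size of 8 is an arbitrary example.

```python
def expand_shape(shape, *sizes):
    """Mimic the shape rule of Tensor.expand: -1 keeps the existing size."""
    return tuple(old if new == -1 else new for old, new in zip(shape, sizes))


def token_count(num_patches, distilled=False):
    """Sequence length after prepending the class (and distillation) token."""
    num_tokens = 2 if distilled else 1
    return num_patches + num_tokens


cls_shape = expand_shape((1, 1, 768), 8, -1, -1)
print(cls_shape)                          # (8, 1, 768)
print(token_count(196))                   # 197
print(token_count(196, distilled=True))   # 198
```

The resulting sequence length matches the second dimension of pos_embed, which is why x + self.pos_embed broadcasts cleanly over the batch.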
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
if self.dist_token is None:
return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]
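- self.pos_drop(x + self.pos_embed) adds the position-encoding sequence and then applies dropout.
- self.blocks(x) runs the blocks: 12 blocks of attention + MLP.
- self.norm(x) applies the final layer normalization.
- return x[:, 0], x[:, 1] returns the tokens at index 0 and index 1 along dimension 1 (the class token and the distillation token). These are the model's output features.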