Vision Transformer 是将Transformer应用在计算机视觉中。Transformer是一个基于注意力的模型,他不依靠卷积神经网络,相比RNN,他可以进行并行运算;相比CNN,计算两者的关系,不会受到距离的远近而增加计算的长度;同时自注意力可以产生更具可解释性的模型。我们可以从模型中检查注意力分布。各个注意头(attention head)可以学会执行不同的任务。虽然Transformer有这么多的优点,但是将其应用到计算机视觉也存在一定的问题,由于在NLP任务中,句子的长度是并不是很长,对于图像,如果以像素为计算单元,一张图片的像素太多,计算量巨大,所以ViT提出将图像进行切块,进行操作。
将图片切成相同大小的patch块,例如一张224x224的图片,切成16x16的块,则可以切成14x14块,对与每一个patch块,展平成1x768的序列,每一个序列前边加一个cls-token,最终将获得196个1x769的序列。
切块操作 代码片
.
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1]
img_size[0]
self.patch_shape = (img_size[0]
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
在NLP语言中,由于Transformer 不像RNN具有时序的关系,他是并行的输入,所以需要确定前后关系,于是提出了位置信息。在视觉中的Transformer中,也是需要添加位置信息,对与每一个patch块进行位置信息添加。
代码片
.
将每一个patch分为num_heads份进行注意力的计算,单独的计算每一份的注意力权重,这里用的是自注意力机制,根据下方公式
计算注意力。
代码片
.
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop) #
def forward(self, x):
'''
B,C,H,W-> B,N,C
:param x:
:return:
'''
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# Batchsize patch数量 3(qkv) 头数(每个atten切分为几头) 宽高通道自适应 -->3(qkv) Batchsize 头数(每个atten切分为几头)patch数量 宽高通道自适应
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 获得每个atten的值
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x