Corresponding paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
Let's go straight to the code.
First, the Transformer class.
# imports used by all the snippets below
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
        # ModuleList is a container that stores modules and automatically registers
        # each module's parameters with the network.
        # Unlike Sequential, it imposes no fixed execution order between its modules:
        # how they are called is decided in forward at run time
        # (a short sketch of the difference follows this class).

    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            x = ff(x)
        return x
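To make the ModuleList comment concrete, here is a minimal sketch (not from the repo) of the difference from Sequential: both register their sub-modules' parameters, but ModuleList has no forward of its own, so the call order is chosen explicitly in your own loop.

import torch
from torch import nn

blocks = nn.ModuleList([nn.Linear(8, 8), nn.ReLU()])
seq = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

x = torch.randn(2, 8)
y = seq(x)               # Sequential calls its modules in the fixed construction order
for m in blocks:         # ModuleList has no forward; this loop decides the order
    x = m(x)
print([name for name, _ in blocks.named_parameters()])  # ['0.weight', '0.bias']: parameters are registered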
self.layers holds instances of several classes; they are explained one by one below, in execution order.
The Attention class
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim ** -0.5

        # dim is the model embedding dimension: the last dimension of the tensor
        # entering the block and of the tensor produced by the output projection
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # the linear layer triples the last dimension; chunk splits it back into q, k, v
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        # einsum sums over the repeated index d, giving the q·k^T attention scores
        # (see the shape check after this class)

        mask_value = -torch.finfo(dots.dtype).max
        # finfo describes the numerical limits of the floating-point dtype
        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            # F.pad is PyTorch's built-in padding function; the extra True covers the prepended cls token
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # attention-weighted sum over v
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
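To see what the two einsum calls compute, here is a small shape check with made-up sizes (not part of the repo): 'bhid,bhjd->bhij' is the q·k^T score matrix and 'bhij,bhjd->bhid' is the attention-weighted sum over v; both are equivalent to plain matrix multiplications.

import torch

b, h, n, d = 2, 8, 65, 64                            # batch, heads, sequence length, head dim
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

dots = torch.einsum('bhid,bhjd->bhij', q, k)         # (b, h, n, n) similarity scores
assert torch.allclose(dots, q @ k.transpose(-1, -2), atol=1e-5)

attn = dots.softmax(dim=-1)                          # each row sums to 1
out = torch.einsum('bhij,bhjd->bhid', attn, v)       # (b, h, n, d) weighted sum of v
assert torch.allclose(out, attn @ v, atol=1e-5)
print(out.shape)                                     # torch.Size([2, 8, 65, 64])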
The second is the FeedForward class.
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),   # Gaussian Error Linear Units, an activation motivated by stochastic regularization
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
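A quick shape check for the MLP block (assuming the FeedForward class above is in scope; the sizes are just for illustration): the first Linear expands to hidden_dim and the second projects back, so the token shape is unchanged.

import torch

ff = FeedForward(dim=1024, hidden_dim=2048, dropout=0.1)
tokens = torch.randn(2, 65, 1024)    # (batch, sequence length, dim)
print(ff(tokens).shape)              # torch.Size([2, 65, 1024]); only the hidden width changes internally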
Next, a quick look at the two custom wrapper modules: the residual connection and the pre-normalization.
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
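Putting the two wrappers together (a minimal sketch assuming the classes above are in scope): Residual(PreNorm(dim, fn)) computes fn(LayerNorm(x)) + x, which is exactly the pre-norm residual branch used twice per layer in the Transformer class.

import torch

dim = 1024
block = Residual(PreNorm(dim, FeedForward(dim, hidden_dim=2048)))

x = torch.randn(2, 65, dim)
y = block(x)                 # FeedForward(LayerNorm(x)) + x
print(y.shape)               # torch.Size([2, 65, 1024])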
The main class, and how the modules above are used in it, is described in detail below.
MIN_NUM_PATCHES = 16   # module-level constant in the repo, referenced by the assert below

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # randn samples from a standard normal distribution.
        # Shape (1, 64 + 1, 1024) in the running example; the +1 is for the cls token;
        # this is a learnable parameter, not a fixed positional encoding.
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # learnable token marking the position read out for classification;
        # without it there would be no designated position to take the prediction from
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
        # nn.Identity does nothing; it is just a pass-through placeholder
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, mask = None):
        p = self.patch_size
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        # Einstein-notation reshaping via the einops library; rearrange is one of its operators.
        # p is the patch size. With a (b, 3, 256, 256) input the image is viewed as
        # (b, 3, 8x32, 8x32) and rearranged into (b, 8x8, 32x32x3) = (b, 64, 3072):
        # each image is split into 64 patches, each flattened to 32*32*3 = 3072 values,
        # i.e. a sequence of 64 tokens with a 3072-dimensional raw encoding
        # (see the shape check after this class).
        x = self.patch_to_embedding(x)
        # 3072 is rather large, so a linear projection reduces it to dim, giving (b, 64, 1024)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        # prepend the cls token, giving (b, 65, 1024)
        x += self.pos_embedding[:, :(n + 1)]
        # num_patches = 64 and dim = 1024; the +1 accounts for the prepended cls token.
        # Adding the position embedding to the patch embeddings gives the encoder input.
        x = self.dropout(x)

        x = self.transformer(x, mask)
        # a standard, unmodified Transformer encoder:
        # with a (b, 65, 1024) input, the output is also (b, 65, 1024)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        # with 'cls' pooling, only the output at position 0 is used for classification
        x = self.to_latent(x)
        return self.mlp_head(x)
        # an fc classification head on top of the encoder
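The patch-splitting rearrange in forward can be checked in isolation; this small sketch uses the 256x256 image / 32x32 patch numbers from the comments above.

import torch
from einops import rearrange

img = torch.randn(2, 3, 256, 256)    # (b, c, H, W)
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
print(patches.shape)                 # torch.Size([2, 64, 3072])
# 64 = (256 / 32) ** 2 patches per image, each flattened to 32 * 32 * 3 = 3072 values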
The paper's use of the Transformer is very clean and simple, which makes it a good way to see how the Transformer approach from natural language processing can be applied to image recognition.
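Finally, a usage sketch for the whole model; the hyperparameters below are illustrative rather than the paper's, and the classes and imports from the snippets above are assumed to be in scope.

import torch

model = ViT(
    image_size=256, patch_size=32, num_classes=1000,
    dim=1024, depth=6, heads=8, mlp_dim=2048,
    dropout=0.1, emb_dropout=0.1
)
img = torch.randn(2, 3, 256, 256)    # a dummy batch of two RGB images
logits = model(img)                  # mask can stay None for plain classification
print(logits.shape)                  # torch.Size([2, 1000])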