Dosovitskiy, Alexey, et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” arXiv preprint arXiv:2010.11929 (2020).
这是一篇奠定了Transformer在视觉领域击败传统卷积的文章,Transformer在NLP领域大放异彩之后,在视觉领域也取得了优异的效果,作者摒弃了所有的卷积操作,将图片分割为若干patch,再进行编码,像文本序列一样输入进Transformer模型中。在中等规模的数据集上取得的效果并没有卷积的效果好,但是在大规模的数据集上的表现已经能够超越卷积。
假设一张图片的尺寸为 H × W × C H \times W \times C H×W×C,patch的尺寸为 P × P P \times P P×P, 那么划分后的图片可以表示为 N × ( P 2 × C ) N \times (P^2 \times C) N×(P2×C), 其中 N = ( H × W ) / P 2 N = (H \times W) / P^2 N=(H×W)/P2。那么一个patch的初始编码长度就等于 ( P 2 × C ) (P^2 \times C) (P2×C)对其进行线性投影和位置编码之后就可以像训练文本一样。此外,如图所示,输入进网络中的patch有九个,但是对于最后用哪个编码结果进行图像分类是很难决定的,于是在网络中额外输入一个用于分类的cls_token,它的维度与patch是一致的,我们可以认为它是一个用于最终分类的人为添加的patch。
注意力机制并不是第一次用于图像处理中,SE(sequeeze and excitation)块其实也是一种注意力机制,不过它是作用于通道维的,而ViT是作用于全局的。每个patch都能与任意通道的patch做注意力。其实,当patch的形状是1x1时,效果就和SE块很类似了。
import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torchvision
from torch.utils import data
import matplotlib.pyplot as plt
import copy
import math
def pair(t):
return t if isinstance(t, tuple) else (t, t)
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super(PreNorm, self).__init__()
self.norm = nn.LayerNorm(normalized_shape=dim)
self.fn = fn
def forward(self, x):
x = self.norm(x)
x = self.fn(x)
return x
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super(FeedForward, self).__init__()
self.net = nn.Sequential(
nn.Linear(in_features=dim, out_features=hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(in_features=hidden_dim, out_features=dim),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.net(x)
return x
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super(Attention, self).__init__()
inner_dim = heads * dim_head
project_out = not(heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** (-0.5)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim*3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
# x [batch_size, 查询个数, dim]
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t:rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) # q,k,v维度相等 [batch_size, num_heads, 查询个数, d]
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super(Transformer, self).__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
])
)
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self,image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 1, dim_head = 64, dropout = 0., emb_dropout = 0.):
super(ViT, self).__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
) # h * w 等于patch的数量
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
# img [batch_size, c, H, W]
x = self.to_patch_embedding(img) # [batch_size, num_patch, dim]
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b) # [batch_size, 1, dim]
x = torch.cat((cls_tokens, x), dim=1) # [batch_size, 1 + num_patch, dim]
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
net = ViT(image_size=(224, 224), patch_size=(32, 32), num_classes=10, dim=768, depth=12, heads=12, mlp_dim=3072)