AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
(An image is worth 16x16 "words": each image is treated as a sequence of 16x16-pixel patches, enabling Transformers for image recognition at scale)
The Transformer has been enormously successful in NLP, so researchers naturally wanted to bring the attention mechanism into computer vision.
Because the Transformer architecture for NLP is already mature, the most direct way to apply it to CV is to reshape the image input into the form the model expects, so that the model can run essentially unchanged.
Core idea:
Treat an image as a sequence of tokens and apply the attention mechanism to it for training and inference.
Input format:
If all the pixels of an image are flattened into one sequence, the input becomes far too long and the cost prohibitive: the sequence length is $M \times N$, where $M$ and $N$ are the image height and width, and self-attention cost grows quadratically with sequence length.
Key innovation:
The paper instead splits the image into small square patches of size $16 \times 16$. For a $224 \times 224$ image, the sequence length becomes $\frac{224}{16} \times \frac{224}{16} = 14 \times 14 = 196$. The patch size can be adjusted to suit the input image resolution.
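A quick sanity check of the patch arithmetic, as a minimal standalone sketch (the variable names here are illustrative, not from the paper's code):
import torch
from einops import rearrange

img = torch.randn(1, 3, 224, 224)  # dummy 224x224 RGB image
# split into 16x16 patches and flatten each one, exactly as ViT's patch embedding does
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
print(patches.shape)  # torch.Size([1, 196, 768]): 196 patches, each 16*16*3 = 768 values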
Overall architecture:
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
# If t is already a tuple, return it as-is; otherwise duplicate it into (t, t), e.g. pair(224) -> (224, 224).
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
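        # only add an output projection when heads/dim_head don't already match the model dimension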
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # split the projected tensor into Q, K, V chunks
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)  # distribute each chunk across the heads
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # scaled dot-product attention: QK^T / sqrt(d)
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
# depth is the number of stacked Transformer blocks; heads is the number of attention heads
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls',
channels=3, dim_head=64, dropout=0., emb_dropout=0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
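        # flatten each patch and project it: (B, C, H, W) -> (B, num_patches, patch_dim) -> (B, num_patches, dim)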
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
nn.Linear(patch_dim, dim),
)
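        # learnable position embeddings: one per patch plus one for the cls token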
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
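A minimal smoke test of the ViT class above, using ViT-Base-style hyperparameters (illustrative values, not from the original post):
model = ViT(image_size=224, patch_size=16, num_classes=1000, dim=768,
            depth=12, heads=12, mlp_dim=3072)
img = torch.randn(1, 3, 224, 224)
logits = model(img)
print(logits.shape)  # torch.Size([1, 1000])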
Training code:
# encoding=UTF-8
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import pandas as pd
import csv
import time
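# models, utils and randomaug below are local modules of the training project, not pip packages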
from models import *
from utils import progress_bar
from randomaug import RandAugment
from models.vit import ViT
from models.convmixer import ConvMixer
# parsers
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')  # ResNets: 1e-3, ViT: 1e-4
parser.add_argument('--opt', default="adam")
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--noaug', action='store_true', help='disable randomaug')
parser.add_argument('--noamp', action='store_true', help='disable mixed precision training (for older pytorch versions)')
parser.add_argument('--nowandb', action='store_true', help='disable wandb')
parser.add_argument('--mixup', action='store_true', help='add mixup augmentations')
parser.add_argument('--net', default='vit')
parser.add_argument('--bs', default=64, type=int, help='batch size')
parser.add_argument('--size', default=32, type=int, help='input image size')
parser.add_argument('--n_epochs', default=200, type=int)
parser.add_argument('--patch', default=4, type=int, help="patch size for ViT")
parser.add_argument('--dimhead', default=512, type=int)
parser.add_argument('--convkernel', default=8, type=int, help="kernel size for ConvMixer")
args = parser.parse_args()
bs = args.bs
imsize = args.size
print('==> Preparing data..')
if args.net == "vit_timm":
size = 384
else:
size = imsize
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.Resize(size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.Resize(size),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Prepare dataset
trainset = torchvision.datasets.CIFAR10(root='./Datasets', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./Datasets', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = ViT(
image_size=size,
patch_size=args.patch,
num_classes=10,
dim=int(args.dimhead),
depth=6,
heads=8,
mlp_dim=512,
dropout=0.1,
emb_dropout=0.1
)
# For Multi-GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
if 'cuda' in device:
print(device)
print("using data parallel")
net = torch.nn.DataParallel(net) # make parallel
cudnn.benchmark = True
if args.resume:
# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # filename and keys match what test() saves below
    checkpoint = torch.load('./checkpoint/{}-{}-ckpt.t7'.format(args.net, args.patch))
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
# Loss is CE
criterion = nn.CrossEntropyLoss()
if args.opt == "adam":
optimizer = optim.Adam(net.parameters(), lr=args.lr)
elif args.opt == "sgd":
optimizer = optim.SGD(net.parameters(), lr=args.lr)
# use cosine scheduling
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs)
# Training
use_amp = not args.noamp  # enable AMP unless --noamp is passed (bool(~x) is always True, so ~ was a bug)
aug = not args.noaug      # use randomaug unless --noaug is passed
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
def train():
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
# Train with amp
with torch.cuda.amp.autocast(enabled=use_amp):
outputs = net(inputs)
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
return train_loss / (batch_idx + 1)
def test():
global best_acc
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
# Save checkpoint.
acc = 100. * correct / total
if acc > best_acc:
print('Saving..')
state = {"model": net.state_dict(),
"optimizer": optimizer.state_dict(),
"scaler": scaler.state_dict()}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/' + args.net + '-{}-ckpt.t7'.format(args.patch))
best_acc = acc
os.makedirs("log", exist_ok=True)
content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}'
print(content)
with open(f'log/log_{args.net}_patch{args.patch}.txt', 'a') as appender:
appender.write(content + "\n")
return test_loss, acc
list_loss = []
list_acc = []
net.to(device)  # net.cuda() would crash on CPU-only machines
for epoch in range(start_epoch, args.n_epochs):
start = time.time()
trainloss = train()
val_loss, acc = test()
# scheduler.step(epoch - 1) # step cosine scheduling
scheduler.step()
list_loss.append(val_loss)
list_acc.append(acc)
# Write out csv..
with open(f'log/log_{args.net}_patch{args.patch}.csv', 'w') as f:
writer = csv.writer(f, lineterminator='\n')
writer.writerow(list_loss)
writer.writerow(list_acc)
# print(list_loss)
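To launch training (assuming the script above is saved as train.py; all flags come from the parser defined in it):
python train.py --net vit --bs 64 --size 32 --patch 4 --n_epochs 200 --lr 1e-4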
Tip 1: If os.system output in PyCharm shows garbled Chinese characters, go to File -> Settings -> Editor -> File Encodings and set Global Encoding to GBK.
Tip 2: The first time you use wandb, it will error out asking for an authorization key; follow the last reference link for the steps. When I got the key, pasting it did nothing, but typing it in manually worked.
Tip 3: vit-pytorch, odach, and wandb can all be installed smoothly with pip.