ViT Paper Study Notes

Paper: "AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE"
(i.e., an image can be treated as a sequence of 16x16 patches for large-scale image recognition with Transformers)

Preface

The Transformer has been hugely successful in NLP, so researchers naturally wanted to bring the attention mechanism into computer vision as well.

Because the Transformer architecture used in NLP is already mature, the most direct way to bring it into CV is to reshape the image input into the form the model expects, so that the model can then be used essentially unchanged.


Model

Core idea: treat an image as a sequence of tokens and train/test it with the attention mechanism.
Input form: if every pixel of an image were used as a token, the input sequence would be far too long and the cost too high (length $M \times N$, where $M$ and $N$ are the image height and width). Innovation: the image is instead split into patches of size $16 \times 16$; for a $224 \times 224$ image the sequence length becomes $\frac{224}{16} \times \frac{224}{16} = 14 \times 14 = 196$. The patch size can be adjusted to the input image size (a small patching sketch follows).
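
To make the patching concrete, here is a minimal sketch (my own illustration, not from the paper) that splits a dummy 224x224 RGB image into 16x16 patches with einops; it is the same operation the implementation below performs with Rearrange:

import torch
from einops import rearrange

# Illustrative only: one random 224x224 RGB image, split into 16x16 patches.
img = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
print(patches.shape)  # torch.Size([1, 196, 768]) -> 196 patches, each flattened to 16*16*3 = 768
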
Overall architecture:
[Figure: ViT architecture — patch + position embedding, the Transformer encoder (shown in detail on the right), and the MLP head]

  1. After the image is split into patches, each patch passes through a linear projection layer (a fully connected layer, written $E$ in the paper, with output dimension $D$) to produce its feature vector (patch embedding);
  2. Because the patches have a spatial order, positional information is added (position embedding);
  3. To make it easy to read a prediction off the final features, an extra learnable [class] embedding (the cls token) is added and placed at the front, at position 0; [since the cls token exchanges information with all the patch tokens (the 9 patches drawn in the paper's figure), it is assumed to learn a representation of the whole image]
  4. Taking a $224 \times 224$ image as an example: splitting it into $16 \times 16$ patches gives a sequence of length 196, and each patch has dimension $16 \times 16 \times 3 = 768$, so the sequence has shape $196 \times 768$. After the $768 \times 768$ fully connected layer $E$ the output is still $196 \times 768$; adding one cls token gives a final Transformer input of shape $197 \times 768$ (Embedded Patches);
  5. The Transformer block is shown on the right of the figure above: LayerNorm first, then multi-head self-attention. With 12 heads, the K, Q, V of each head have shape $197 \times \frac{768}{12} = 197 \times 64$, and the heads are concatenated back to $197 \times 768$;
  6. The MLP then typically expands the dimension by 4x to $197 \times 3072$ and projects it back to $197 \times 768$ (a quick shape check of steps 4-6 follows this list).
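
A rough shape check for steps 4-6, assuming the ViT-Base numbers used above (an illustration only, not the actual model code):

import torch
import torch.nn as nn

seq = torch.randn(1, 196, 768)            # 196 flattened 16x16x3 patches
E = nn.Linear(768, 768)                   # the linear projection E (768 -> D = 768)
tokens = E(seq)                           # still (1, 196, 768)

cls = torch.zeros(1, 1, 768)              # stand-in for the learnable [class] token
tokens = torch.cat([cls, tokens], dim=1)  # (1, 197, 768): what enters the Transformer

heads = 12
dim_per_head = tokens.shape[-1] // heads  # 768 / 12 = 64, the per-head K/Q/V dim
mlp_hidden = 4 * tokens.shape[-1]         # MLP expands 768 -> 3072, then back to 768
print(tokens.shape, dim_per_head, mlp_hidden)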

The steps above are written out as equations in the figure below:
[Figure: the ViT forward pass written as the paper's equations]
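
Since the figure itself is not reproduced here, the four equations, as I recall them from the paper (Eq. 1-4), are roughly:

$z_0 = [x_{\text{class}};\ x_p^1 E;\ x_p^2 E;\ \cdots;\ x_p^N E] + E_{pos}$, where $E \in \mathbb{R}^{(P^2 \cdot C) \times D}$ is the patch projection and $E_{pos} \in \mathbb{R}^{(N+1) \times D}$ the position embedding
$z'_\ell = \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}$, for $\ell = 1, \dots, L$
$z_\ell = \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell$, for $\ell = 1, \dots, L$
$y = \mathrm{LN}(z_L^0)$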


Ablation Studies

  • Readout: the experiments compare global average pooling (GAP) over all output tokens with using the class token; the two perform almost identically;
  • Positional encoding: 1-D encoding, 2-D encoding, and relative positional embeddings are compared.
    [Figure: ablation results for the different positional encodings]
    As the figure shows, the different encodings perform about equally well.
    To stay as close as possible to the standard Transformer, the paper finally uses the class token together with 1-D positional encoding.
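
As a side note, here is a minimal sketch (my own illustration, not the paper's code) of what learnable 1-D versus 2-D position embeddings look like for a 14x14 patch grid with D = 768; the model code below only implements the 1-D variant:

import torch
import torch.nn as nn

D, H, W = 768, 14, 14

# 1-D: one learnable vector per flattened patch index (the variant ViT keeps).
pos_1d = nn.Parameter(torch.randn(H * W, D))

# 2-D: separate row and column embeddings of size D/2, concatenated per patch.
row_emb = nn.Parameter(torch.randn(H, D // 2))
col_emb = nn.Parameter(torch.randn(W, D // 2))
pos_2d = torch.cat([
    row_emb[:, None, :].expand(H, W, D // 2),
    col_emb[None, :, :].expand(H, W, D // 2),
], dim=-1).reshape(H * W, D)

print(pos_1d.shape, pos_2d.shape)  # both torch.Size([196, 768])

Model code (matching the lucidrains/vit-pytorch implementation linked in the references):
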
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers
# pair(t): if t is already a tuple, return it as-is; otherwise duplicate it into the tuple (t, t).

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes


# Apply LayerNorm before the wrapped sub-layer fn (pre-norm Transformer style)
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


# Transformer MLP block: Linear -> GELU -> Dropout -> Linear -> Dropout
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # split the projected tensor into Q, K, V
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)  # split each of Q, K, V across the heads

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # scaled dot-product attention scores

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# depth = number of stacked Transformer blocks, heads = number of attention heads
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls',
                 channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # +1 slot for the cls token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))                    # learnable [class] embedding
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)  # one cls token per sample in the batch
        x = torch.cat((cls_tokens, x), dim=1)                        # prepend cls token: (b, n+1, dim)
        x += self.pos_embedding[:, :(n + 1)]                         # add learnable 1-D position embeddings
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]  # readout: mean-pool all tokens, or take the cls token

        x = self.to_latent(x)
        return self.mlp_head(x)
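
A quick sanity check of the module with ViT-Base-like hyperparameters (the numbers below are just for illustration):

# Dummy forward pass: two random 224x224 RGB images -> 1000 class logits each.
model = ViT(image_size=224, patch_size=16, num_classes=1000, dim=768,
            depth=12, heads=12, mlp_dim=3072)
imgs = torch.randn(2, 3, 224, 224)
print(model(imgs).shape)  # torch.Size([2, 1000])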

Training code:

# encoding=UTF-8
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import pandas as pd
import csv
import time

from models import *
from utils import progress_bar
from randomaug import RandAugment
from models.vit import ViT
from models.convmixer import ConvMixer

# parsers
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')  # ResNet: 1e-3, ViT: 1e-4
parser.add_argument('--opt', default="adam")
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--noaug', action='store_true', help='disable RandAugment')
parser.add_argument('--noamp', action='store_true', help='disable mixed precision training (for older PyTorch versions)')
parser.add_argument('--nowandb', action='store_true', help='disable wandb')
parser.add_argument('--mixup', action='store_true', help='add mixup augmentation')
parser.add_argument('--net', default='vit')
parser.add_argument('--bs', default='64')
parser.add_argument('--size', default="32")
parser.add_argument('--n_epochs', type=int, default='200')
parser.add_argument('--patch', default='4', type=int, help="patch for ViT")
parser.add_argument('--dimhead', default="512", type=int)
parser.add_argument('--convkernel', default='8', type=int, help="parameter for convmixer")

args = parser.parse_args()

bs = int(args.bs)
imsize = int(args.size)

print('==> Preparing data..')
if args.net == "vit_timm":
    size = 384
else:
    size = imsize

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Prepare dataset
trainset = torchvision.datasets.CIFAR10(root='./Datasets', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./Datasets', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

net = ViT(
    image_size=size,
    patch_size=args.patch,
    num_classes=10,
    dim=int(args.dimhead),
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
    emb_dropout=0.1
)

# For Multi-GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

if 'cuda' in device:
    print(device)
    print("using data parallel")
    net = torch.nn.DataParallel(net)  # make parallel
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/{}-{}-ckpt.t7'.format(args.net, args.patch))  # match the save path used in test()
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

# Loss is CE
criterion = nn.CrossEntropyLoss()

if args.opt == "adam":
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
elif args.opt == "sgd":
    optimizer = optim.SGD(net.parameters(), lr=args.lr)

# use cosine scheduling
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs)

# Training
use_amp = not args.noamp  # bool(~args.noamp) was always True; `not` actually honors --noamp
aug = not args.noaug      # whether RandAugment should be applied
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train():
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Train with amp
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = net(inputs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    return train_loss / (batch_idx + 1)


def test():
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {"model": net.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "scaler": scaler.state_dict()}
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/' + args.net + '-{}-ckpt.t7'.format(args.patch))
        best_acc = acc

    os.makedirs("log", exist_ok=True)
    content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}'
    print(content)
    with open(f'log/log_{args.net}_patch{args.patch}.txt', 'a') as appender:
        appender.write(content + "\n")
    return test_loss, acc


list_loss = []
list_acc = []

net.to(device)  # use the GPU when available, otherwise fall back to CPU
for epoch in range(start_epoch, args.n_epochs):
    start = time.time()
    trainloss = train()
    val_loss, acc = test()

    # scheduler.step(epoch - 1)  # step cosine scheduling
    scheduler.step()

    list_loss.append(val_loss)
    list_acc.append(acc)

    # Write out csv..
    with open(f'log/log_{args.net}_patch{args.patch}.csv', 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(list_loss)
        writer.writerow(list_acc)
    # print(list_loss)


Tip 1: if Chinese text comes out garbled when PyCharm runs os.system, go to File -> Settings -> Editor -> File Encodings and set Global Encoding to GBK.
Tip 2: the first time wandb runs it asks for an authorization key (see the last reference link). Pasting the key did nothing for me, but typing it in manually worked.
Tip 3: vit-pytorch, odach, and wandb can all be installed directly with pip.


Summary

[Figure: accuracy of ViT vs. BiT (ResNet) as the pre-training dataset grows]

  1. When pre-trained on small and medium-sized datasets, ViT does worse than the ResNet-based BiT models. This is because ViT has far less image-specific inductive bias than a CNN: the position embeddings carry no 2-D patch-layout information at initialization, so every spatial relationship has to be learned from scratch.
  2. As the pre-training dataset keeps growing (from ImageNet-21k up to JFT-300M), ViT's results overtake the ResNet models almost across the board. So to get the full benefit of ViT, a very large pre-training dataset is essential.

[Figure: performance versus pre-training compute for ViT, BiT (ResNet), and the hybrid models]

  1. In the performance-versus-compute trade-off, ViT dominates ResNet: it reaches the same performance with roughly 2-4x less pre-training compute;
  2. The hybrid architecture (CNN feature maps fed into the Transformer) beats plain ViT at smaller compute budgets, but the gap disappears at larger scale;
  3. ViT shows no sign of saturating within the range of experiments tried, so there is still plenty of headroom.

References

  • https://www.bilibili.com/video/BV15P4y137jb/?spm_id_from=333.788&vd_source=94f79d8adeec4791b8751d7cb539ce55&t=1901.3
  • https://blog.csdn.net/weixin_43312117/article/details/124085898
  • https://www.bilibili.com/video/BV1Uu411o7oY?p=2&spm_id_from=pageDriver&vd_source=94f79d8adeec4791b8751d7cb539ce55
  • https://blog.csdn.net/weixin_44966641/article/details/118733341
  • https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
  • https://blog.csdn.net/weixin_46088823/article/details/121766888
