Transformer 在图像中的运用(一)VIT(Transformers for Image Recognition at Scale)论文及代码解读

接着前面的文章说到的transformer,本篇将要介绍在图像中如何将transformer运用到图片分类中去的。我们知道CNN具有平移不变形,但是transformer基于self-attentation可以获得long-range信息(更大的感受野),但是CNN需要更多深层的Conv-layers来不断增大感受野。

这里将给出论文地址及代码地址:
论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
官方代码地址:https://github.com/google-research/vision_transformer(本文讲解对象)
jeonsworld/ViT-pytorch(相对来说容易理解)

本博客讲解代码地址:https://github.com/lucidrains/vit-pytorch
这里选择代码地址的原因首先是因为其star比较高,其次拥有多种变形模型及使用的pytorch框架进行编写,便于代码阅读。

一、论文阅读


这里主要讲解论文重点的部分

1. 优点

相当于卷积模型来比,transformer在减少计算资源的同时获得了非常出色的结果。当对中等规模数据集(例如ImageNet)进行训练的时候,此模型所产生的适合的精度要比同等规模的ResNet低几个百分点。数据量越大,模型越友好。其中的Attention机制是一个很重要的机制我们通过一下图可以看出其优势, 让我们关注我们需要的物体,忽略没有用的东西。

Attention

传统的卷积神经网络需要大的感受野需要不断地卷积,才能获得更大的感受野,但是在我们Transformer的模式下,我们在浅层就已经获得很大的感受野,在做attention的时候就已经看到了所有的信息,所以说其根本不需要堆叠,直接就可以获得全局信息(这里看完文章就会有感悟)。
纵坐标代表感受野

2. 启发

很多图像特征提取器将CNN与专门注意力机制结合在一起,但是未能在硬件加速器扩展。

3. 与CNN的差距原因

transformer缺乏CNN固有的一些感应偏差,例如平移不变性和局部性,因此在训练不足的数量时候,很难有好的效果。

4. VIT原理:

vit.gif

  • 步骤一(input):
    image.png

    输入图像大小尺寸为(), 首先我们将图片进行切分,按照 patch_size进行切分,这样我们就得到了大小的一个个图块, 这里的图块数量为. 联想到transformer,这里的N就可以理解是序列长度,其中序列中每个element的维度dim称之为patch embedding
    在我们进行图片分类的时候我们一般在序列前加入一个element,我们称此element为, 这样我们得到序列长度为N+1,在训练的时候我们可以通过此element进行图片分类。最后再加上位置矩阵(注意这里是add不是concate)构成我们的输入矩阵z0

步骤二(forward):
transformer编码器主要由两个components构成分别是MSA(multi-head self-attention)MLP([MultiLayer Perceptron)组成。下面是前向传播的计算公式:

  • 第一个公式
    这里表示的是类别element,表示的是输入的每个patch,E代表的对应的权重,N表示的patch的数量 代表的是position的信息, 这里不加位置编码,加一维编码(即1,2,3。。。),以及加位置编码的效果如下, 我们发现有位置编码比没有的效果好,但是多少维的效果差不多,我们一般采用2维度。下面是分类不同维度编码效果,但是检测任务就不一定了。

    位置编码

  • 第二个公式
    这里的LN表示的Layer Normalization,MSA的公式如下:


    这里的qkv矩阵之前说过了,如通过输入z与权重得到而来,我们在通过公式得到我们的Attention权重。最终利用v矩阵与attention权重相乘得到。因为考虑到多头记住所以我们得到如下公式:
    MLP其实就是多层感知机,这里很容易理解,其包含具有GELU非线性的两全连接层。
    还有一点注意的是根据如下公式可以看到vit模型结构也采用了残差机制:
    image.png

5 微调和更高的分辨率

  • 微调:
    可删除预训练的head,附加初始化foward层, K是类的数量。

  • 更高分辨率:
    当提供更高分辨率的图像时,我们将图块大小保持不变,这会导致更大的有效序列长度。ViT可以处理任意序列长度(直到内存限制),但是,预训练的位置embedding可能不再有意义。因此,我们根据预先训练的位置嵌入在原始图像中的位置执行2D插值。请注意,只有在分辨率调整和色块提取中,将有关图像2D结构的感应偏差手动注入到Vision Transformer中。

二、代码解读

本博客讲解代码地址:https://github.com/lucidrains/vit-pytorch
这里主要做的是猫狗分类模型,图片大小为256*256

1. 主函数

这里的main函数可以理解为是常规操作,大家稍微看下就理解了。

#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Maocheng Hu
@project_name:vit
@file:train.py
@time:2021/04/27/14/17
@ide:PyCharm
@email: [email protected]

              ┏┓      ┏┓
            ┏┛┻━━━┛┻┓
            ┃        ┃
            ┃  ┳┛  ┗┳  ┃
            ┃      ┻      ┃
            ┗━┓      ┏━┛
                ┃      ┗━━━┓
                ┃  神兽保佑    ┣┓
                ┃ 永无BUG!   ┏┛
                ┗┓┓┏━┳┓┏┛
                  ┃┫┫  ┃┫┫
                  ┗┻┛  ┗┻┛
"""

import os

import glob
from tqdm import tqdm
import torch
import torch.nn as nn
from PIL import Image

import torch.optim as optim
from linformer import Linformer
# from vit_pytorch.efficient import ViT
from vit_pytorch import ViT
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split


# image augmentation
def image_augmentation():
    train_transforms = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )

    val_transforms = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )

    test_transforms = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    return train_transforms, val_transforms, test_transforms


# load data
def load_data(dataset):
    train_list = glob.glob(os.path.join("{}/{}".format(dataset, "train"), "*.jpg"))
    test_list = glob.glob(os.path.join("{}/{}".format(dataset, "test"), "*.jpg"))
    train_label_list = [path.split('/')[-1].split('.')[0] for path in train_list]
    # stratify for balancing classes
    train_list, valid_list = train_test_split(train_list,
                                              test_size=0.2,
                                              stratify=train_label_list,
                                              random_state=2021)
    return train_list, valid_list, test_list


class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transformer = transform

    def __len__(self):
        self.file_length = len(self.file_list)
        return self.file_length

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transformer(img)
        label = img_path.split("/")[-1].split(".")[0]
        label = 1 if label == "dog" else 0
        return img_transformed, label





def main():
    dataset = "dataset"
    # Training settings
    batch_size = 1
    epochs = 20
    lr = 3e-5
    gamma = 0.7
    seed = 42
    device = 'cuda'

    train_transforms, val_transforms, test_transforms = image_augmentation()
    train_list, valid_list, test_list = load_data(dataset)
    train_data = CatsDogsDataset(train_list, transform=train_transforms)
    valid_data = CatsDogsDataset(valid_list, transform=test_transforms)
    test_data = CatsDogsDataset(test_list, transform=test_transforms)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=0)
    valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True, num_workers=0)

    # efficient_transformer = Linformer(
    #     dim=128,
    #     seq_len=49 + 1,  # 7x7 patches + 1 cls-token
    #     depth=12,
    #     heads=8,
    #     k=64
    # )

    model = ViT(
        image_size=224,
        patch_size=32,
        num_classes=1000,
        dim=1024,
        depth=6,
        heads=16,
        mlp_dim=2048,
        dropout=0.1,
        emb_dropout=0.1
    ).to(device)

    # loss function
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # scheduler
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    for epoch in range(epochs):
        epoch_loss = 0
        epoch_accuracy = 0

        for data, label in tqdm(train_loader):
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            loss = criterion(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = (output.argmax(dim=1) == label).float().mean()
            epoch_accuracy += acc / len(train_loader)
            epoch_loss += loss / len(train_loader)

        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            for data, label in valid_loader:
                data = data.to(device)
                label = label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                acc = (val_output.argmax(dim=1) == label).float().mean()
                epoch_val_accuracy += acc / len(valid_loader)
                epoch_val_loss += val_loss / len(valid_loader)

        print(
            f"Epoch : {epoch + 1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
        )


if __name__ == '__main__':
    main()

2. 输入构成

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2  # total patch dimension
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

关于输入我们只要看如下代码即可

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

首先我们先进行参数介绍

  • image_size 图片大小这里为224*224
  • patch_size 表示的是patch大小这里为 32 * 32,所以我们可以得到patch_num为7 * 7
  • num_classes 为图片类别数量 这里为 2
  • dim 表示的是序列每一个element的维度大小,这里为 1024
  • depth 表示的transformer模型的层数
  • heads 表示的是Multi-head Attention layer的head数,这里为16
  • mlp_dim MLP层的hidden dim
  • emb_dropout 对于输入做dropout

(1)确定patch size 大小

assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' # 注意这里的patch size的大小必须能被图片尺寸整除
        num_patches = (image_size // patch_size) ** 2 # 我们可以得到 num_patch数量为224^2//32^2 为7 * 7即49个。 

(2) 确定patch dim大小

 patch_dim = channels * patch_size ** 2   # 可以理解将图片所有像素重新排列等到patch的通道数即我们这一步可以得到(patch_num,   patch_dim), 这里的patch_dim = 3 * 32 * 32 = 3072

(3) 使用分类方法:

assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' # 这里我们使用 cls , 即单独用一个element来综合我们的特征信息,如果是用mean的话总是用我们的到的z(除去cls的特征信息的平局值)的平局值来综和我们的信息。如果这里不理解,后面也会介绍的。

(4) 维度转换:

 self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )
# 1. 这里的b代表batch_size, c代表channel,h以及w分别代表图像的高以及宽,p1及p2代表图像横纵切分的patch_size大小,所以Rerrange的矩阵大小为[batch_size, (7, 7),(32, 32, 3)] -->[batch_size, 49, 3072]
# 2. 经过nn.linear(3072, 1024), 最后我们得到(batch_size, 49, 1024)

也有的用卷积方式

self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)

(5) 加入类别位置


cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
x = torch.cat((cls_tokens, x), dim=1)
#  这里的操作相当于给input(batch_size, 49, 1024)第二个维度在增加一个类别维度得到(batch_size, 50, 1024)

(6) 地址编码
如下图所示。我们发现,位置越接近,往往具有更相似的位置编码。此外,出现了行列结构;同一行/列中的patch具有相似的位置编码。


image.png
#  这里相当于进行位置编码
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
x += self.pos_embedding[:, :(n + 1)]
# 注意这里的位置编码是直接相加的

(7)输入加dropout

self.dropout = nn.Dropout(emb_dropout)
x = self.dropout(x)

最终我们得到的输入尺寸为(batch_size, 50, 1024)

3. 模型构建

通过上述的输入x,将输入到我们的transformer模型里

# 1. 输入模型得到结果
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
x = self.transformer(x)

# 2. Transformer模型
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# 3.  Attentation机制
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 4. layer norm
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# 5. forward 机制
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# 6 输出
  x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
  return self.mlp_head(x)

下面我们分解说下

(1)输入模型得到结果
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
x = self.transformer(x)
(2) Transformer机制
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

这里我们主要关注两个部分:
第一个部分

PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))

第一个部分是attention机制, 第二个部分是forward机制。同时for循环,表示的多层机制。

第二个部分

x = attn(x) + x
x = ff(x) + x

这里使用残差的方式。

(3)Attention机制
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.): # dim为50, heads为16
        super().__init__()
        inner_dim = dim_head * heads  # 这里的inner_dim 为 单个 头维度64,heads为头数量(这里为16),所以inner_dim为1024
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) # 这里的to_qkv为[50, 1024 * 3]

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads  # 这里的b为batch_size, n为50,  _为1024, h为16
        qkv = self.to_qkv(x).chunk(3, dim=-1) # 通过同一个线性并行权重计算得到[50, 1024*3], 再通过chunk最后一个维度切分得到([50, 1024], [50, 1024], [50, 1024])的tuple形式。
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) # 通过一下转换分别得到q, k, v矩阵, 就是讲[batch_size, 50, 16,  64], 转换成[batch_size, 16, 50, 64]

        # 下面需要根据公式进行qkv操作了,从而得到输出z。即z = softmax(Q * K^T/sqrt(d_k))V

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # 这里是将q与k相乘并除上sqrt(d_k)得到z_0
        attn = self.attend(dots) # 对z_0进行softmax得到z_1

        out = einsum('b h i j, b h j d -> b h i d', attn, v) # 这里将结果z_1与矩阵V进行相乘
        out = rearrange(out, 'b h n d -> b n (h d)') # 这里对最终结果进行rerange得到shap为[batch_size, 50, 16,*64]
        return self.to_out(out) # 过线性连接得到矩阵shape[1, 50, 1024]

这里是先用全连接在差分qkv矩阵也有分别在一开始直接用全连接生成qkv矩阵, 如下所示:

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

(4) layer norm

# 这里相当于每一次无论是attention还是forward都是首先对输入矩阵进行layer norm
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

(5)Feedforward
下面很容易理解了就是MLP的一部分了。相当于用了线性连接,这里dim=50, hidden_dim=64, 这里使用的激活函数为GELU。最终我们得到[batch_size, 50,1024]矩阵。记住这里的多层矩阵输出仍然是[batch_size, 50, 1024]

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

(6)输出

x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0] # 这里x的shape为[batch_size, 50, 1024]输出
就是我之前说的如果我们采用mean, 则是在第二个维度求平均,如果不是mean其实第二个维度第一个就是特征表达。最终我们得到[1, 1024]个维度。
x = self.to_latent(x) # 相当于一个容器,把输入都保留下来了。这里我认为相当于保存特征,方便后面finetune操作。

self.mlp_head = nn.Sequential(
  nn.LayerNorm(dim),
  nn.Linear(dim, num_classes)
)
return self.mlp_head(x) # 多层感知机, 输出类别的概率

4. 损失函数

# loss function
criterion = nn.CrossEntropyLoss()

参考

  1. GELU 激活函数
  2. CNN的平移不变性是什么?
  3. VIT Vision Transformer | 先从PyTorch代码了解
  4. 论文笔记 ViT

你可能感兴趣的:(Transformer 在图像中的运用(一)VIT(Transformers for Image Recognition at Scale)论文及代码解读)