一文详解Vision Transformer(附代码)

Transformer 在 NLP 中大获成功,Vision Transformer 则将 Transformer 模型架构扩展到计算机视觉的领域中,并且它可以很好的地取代卷积操作,在不依赖卷积的情况下,依然可以在图像分类任务上达到很好的效果。卷积操作只能考虑到局部的特征信息,而 Transformer 中的注意力机制可以综合考量全局的特征信息。

Vision Transformer 尽力做到在不改变 Transformer 中 Encoder 架构的前提下,直接将其从 NLP 领域迁移到计算机视觉领域中,目的是让原始的 Transformer 模型开箱即用。

干货推荐

  • 浙大博士导师深度整理:Tensorflow 和 Pytorch 的笔记(包含经典项目实战)
  • Python 程序员需要掌握的机器学习“四大名著”发布啦
  • 值得收藏,这份机器学习算法资料着实太香
  • 比 PyTorch 的官方文档还香啊,吃透PyTorch中文版来了
  • 赶快收藏,PyTorch 常用代码段PDF合辑版来了

注意力机制应用

在正式详细介绍 Vision Transformer 之前,先介绍两个注意力机制在计算机视觉中应用的例子。Vision Transformer 并不是第一个将注意力机制应用到计算机视觉的领域中去的,其中 SAGAN 和 AttnGAN 就早已经在 GAN 的框架中引入了注意力机制,并且它们大大提高了图像生成的质量。

2.1 Self-Attention GAN

一文详解Vision Transformer(附代码)_第1张图片

SAGAN 在 GAN 的框架中利用自注意力机制来捕获图像特征的长距离依赖关系,使得合成的图像中考量了所有的图像特征信息。SAGAN 中自注意力机制的操作原理如上图所示。

给定一个 3 通道的输入特征图 ,其中 ,。将 分别输入到三个不同的 的卷积层中,并生成 query 特征图 ,key 特征图 和 value 特征图 。生成 具体的计算过程为,给定三个卷积核 , 和 ,并用这三个卷积核分别与 做卷积运算得到 , 和 ,即:

一文详解Vision Transformer(附代码)_第2张图片

其中 表示卷积运算符号。同理生成 和 的计算过程与 的计算过程类似。然后再利用 和 进行注意力分数的计算得到矩阵 ,其中矩阵 的元素 的计算公式为:

图片

再对矩阵 利用 softmax 函数进行注意力分布的计算得到注意力分布矩阵 ,其中矩阵 的元素 的计算公式为:

一文详解Vision Transformer(附代码)_第3张图片

最后利用注意力分布矩阵 和value特征图 得到最后的输出 ,即:

一文详解Vision Transformer(附代码)_第4张图片

2.2 AttnGAN

一文详解Vision Transformer(附代码)_第5张图片

AttnGAN 通过利用注意力机制来实现多阶段细颗粒度的文本到图像的生成,它可以通过关注自然语言中的一些重要单词来对图像的不同子区域进行合成。比如通过文本“一只鸟有黄色的羽毛和黑色的眼睛”来生成图像时,会对关键词“鸟”,“羽毛”,“眼睛”,“黄色”,“黑色”给予不同的生成权重,并根据这些关键词的引导在图像的不同的子区域中进行细节的丰富。AttnGAN 中注意力机制的操作原理如上图所示。

给定输入图像特征向量 和词特征向量 ,其中 ,,。首先利用矩阵 进行线性变换将词特征空间 的向量转换成图像特征空间 的向量,则有:

一文详解Vision Transformer(附代码)_第6张图片

然后再利用转换后的词特征 与图像特征 进行注意力分数的计算得到注意力分数矩阵 ,其中的分量 的计算公式为:

图片

再对矩阵 利用 函数进行注意力分布的计算得到注意力分布矩阵 ,其中矩阵 的元素 的计算公式为:

一文详解Vision Transformer(附代码)_第7张图片

最后利用注意力分布矩阵 和图像特征 得到最后的输出 ,即:

一文详解Vision Transformer(附代码)_第8张图片

Vision Transformer

本节主要详细介绍 Vision Transformer 的工作原理,3.1 节是关于 Vision Transformer 的整体框架,3.2 节是关于 Transformer Encoder 的内部操作细节。对于 Transformer Encoder 中 Multi-Head Attention 的原理本文不会赘述.

不难发现,不管是自然语言处理中的 Transformer,还是计算机视觉中图像生成的 SAGAN,以及文本生成图像的 AttnGAN,它们核心模块中注意力机制的主要目的就是求出注意力分布。

3.1 Vision Transformer 整体框架

如果下图所示为 Vision Transformer 的整体框架以及相应的训练流程。

  • 给定一张图片 ,并将它分割成 9 个 patch 分别为 。然后再将这个 9 个 patch 拉平,则有 ;

  • 利用矩阵 将拉平后的向量 经过线性变换得到图像编码向量 ,具体的计算公式为:

    图片

  • 然后将图像编码向量 和类编码向量 分别与对应的位置编进行加和得到输入编码向量,则有:

    图片

  • 接着将输入编码向量输入到 Vision Transformer Encoder 中得到对应的输出 ;

  • 最后将类编码向量 输入全连接神经网络中 MLP 得到类别预测向量 ,并与真实类别向量 计算交叉熵损失得到损失值 loss,利用优化算法更新模型的权重参数。

注意事项:

看到这里可能会有一个疑问为什么预测类别的时候只用到了类别编码向量 ,Vision Transformer Encoder 其它的输出为什么没有输入到 MLP 中?为了回答这个问题,我们令函数 为 Vision Transformer Encoder},则类编码向量 可以表示为:

图片

由上公式可以发现,类编码向量 是属于高层特征,其实它综合了所有的图像编码信息,所以可以用它来进行分类,这个可以类比在卷积神经网络中最后的类别输出向量其实就是一层层卷积得到的高层特征。

一文详解Vision Transformer(附代码)_第9张图片

3.2 Transformer Encoder操作原理

如下图所示分别为 Vision Transformer Encoder 模型结构图和原始 Transformer Encoder 的模型结构图。可以直观的发现 Vision Transformer Encoder 和 Transformer Encoder 都有层归一化,多头注意力机制,残差连接和线性变换这四个操作,只是在操作顺序有所不同。在以下的 Transformer 代码实例中,将以下两种 Encoder 网络结构都进行了实现,可以发现两种网络结构都可以进行很好的训练。

下图左半部分 Vision Transformer Encoder 具体的操作流程为:

  • 给定输入编码矩阵 ,首先将其进行层归一化得到 ;

  • 利用矩阵 对 进行线性变换得到矩阵 。具体的计算过程为:

    一文详解Vision Transformer(附代码)_第10张图片

  • 将 进行第二次层归一化得到 ,然后再将 输入到全连接神经网络中进行线性变换得到 。最后将 与 进行残差操作得到该 Block 的输出;。一个 Encoder 可以将 个 Block 进行堆叠,最后得到的输出为 。

一文详解Vision Transformer(附代码)_第11张图片

程序代码

Vision Transformer 的作者的本意就是想让在 NLP 中的 Transformer 模型架构做尽可能少的修改可以直接迁移到 CV 中,所以以下程序尽可能保持作者的愿意,并在代码实现了两种Encoder 的网络结构,即 3.2 节图片所示的两个网络结构,一种是最原始的Encoder 网络结构,一种是 Vision Transformer。论文里的 Encoder 的网络结构。

这里需要注意的是,Vision Transformer 里并能没有 Decoder 模块,所以不需要计算 Encoder 和 Decoder 的交叉注意力分布,这就进一步给 Vision Transformer 的编程带来了简便。Vision Transformer的开源代码的网址为:

https://github.com/lucidrains/vit-pytorch/tree/main/vit_pytorch

import torch
import torch.nn as nn
import os
from einops import rearrange
from einops import repeat
from einops.layers.torch import Rearrange

def inputs_deal(inputs):
    return inputs if isinstance(inputs, tuple) else(inputs, inputs)

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N =query.shape[0]
        value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]

        # split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape : (N, key_len, heads, heads_dim)
        # energy shape: (N, heads, query_len, key_len)

        attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)

        out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # (N, query_len, heads, head_dim)

        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)


    def forward(self, value, key, query, x, type_mode):
        if type_mode == 'original':
            attention = self.attention(value, key, query)
            x = self.dropout(self.norm(attention + x))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm(forward + x))
            return out
        else:
            attention = self.attention(self.norm(value), self.norm(key), self.norm(query))
            x =self.dropout(attention + x)
            forward = self.feed_forward(self.norm(x))
            out = self.dropout(forward + x)
            return out

class TransformerEncoder(nn.Module):
    def __init__(
            self,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout = 0,
            type_mode = 'original'
        ):
        super(TransformerEncoder, self).__init__()
        self.embed_size = embed_size
        self.type_mode = type_mode
        self.Query_Key_Value = nn.Linear(embed_size, embed_size * 3, bias = False)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                    )
                for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        for layer in self.layers:
            QKV_list = self.Query_Key_Value(x).chunk(3, dim = -1)
            x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)
        return x

class VisionTransformer(nn.Module):
    def __init__(self,
                image_size,
                patch_size,
                num_classes,
                embed_size,
                num_layers,
                heads,
                mlp_dim,
                pool = 'cls',
                channels = 3,
                dropout = 0,
                emb_dropout = 0.1,
                type_mode = 'vit'):
        super(VisionTransformer, self).__init__()
        img_h, img_w = inputs_deal(image_size)
        patch_h, patch_w = inputs_deal(patch_size)

        assert img_h % patch_h == 0 and img_w % patch_w == 0, 'Img dimensions can be divisible by the patch dimensions'

        num_patches = (img_h // patch_h) * (img_w // patch_w)

        patch_size = channels * patch_h * patch_w

        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2=patch_w),
            nn.Linear(patch_size, embed_size, bias=False)
        )


        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))
        self.dropout = nn.Dropout(emb_dropout)



        self.transformer = TransformerEncoder(embed_size,
                                    num_layers,
                                    heads,
                                    mlp_dim,
                                    dropout)
        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes)
        )

    def forward(self, img):
        x = self.patch_embedding(img)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d ->b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)


if __name__ == '__main__':
    vit = VisionTransformer(
            image_size = 256,
            patch_size = 16,
            num_classes = 10,
            embed_size = 256,
            num_layers = 6,
            heads = 8,
            mlp_dim = 512,
            dropout = 0.1,
            emb_dropout = 0.1
        )
    img = torch.randn(3, 3, 256, 256)
    pred = vit(img)
    print(pred)

以下代码是利用 Vision Transformer 网络结构训练一个分类 mnist 数据集的主程序代码。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import VIT
import os


def train():
    batch_size = 4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    epoches = 20
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= batch_size, shuffle=True)
    mnist_model = VIT.VisionTransformer(
        image_size = 28,
        patch_size = 7,
        num_classes = 10,
        channels = 1,
        embed_size = 512,
        num_layers = 1,
        heads = 2,
        mlp_dim =1024,
        dropout = 0,
        emb_dropout = 0)
    loss_fn = nn.CrossEntropyLoss()
    mnist_model = mnist_model.to(device)
    opitimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)
    mnist_model.train()
    for epoch in range(epoches):
        total_loss = 0
        corrects = 0
        num = 0
        for batch_X, batch_Y in train_loader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            opitimizer.zero_grad()
            outputs = mnist_model(batch_X)
            _, pred = torch.max(outputs.data, 1)
            loss = loss_fn(outputs, batch_Y)
            loss.backward()
            opitimizer.step()
            total_loss += loss.item()
            corrects = torch.sum(pred == batch_Y.data)
            num += batch_size
            print(epoch, total_loss/float(num), corrects.item()/float(batch_size))

if __name__ == '__main__':
    train()

训练的过程如下所示,可以发现损失函数可以稳定下降。但是训练一个 Vision Transformer 模型真的是很烧硬件,跟训练一个普通的 CNN 模型相比,训练一个 Vision Transformer 模型更加耗时耗力。

一文详解Vision Transformer(附代码)_第12张图片

技术交流

目前已开通了技术交流群,群友已超过1000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友

  • 方式①、发送如下图片至微信,长按识别,后台回复:加群
  • 方式②、微信搜索公众号:机器学习社区,后台回复:加群
  • 方式③、可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。
    在这里插入图片描述

你可能感兴趣的:(机器学习社区,transformer,深度学习,人工智能)