【一站式梳理】ViT - Vision Transformer 流程+代码 学习记录

ViT Paper: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE, ICLR 2021.

【一站式梳理】ViT - Vision Transformer 流程+代码 学习记录_第1张图片

目录

ViT Paper: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE, ICLR 2021.

模型流程:

输入:

Transformer Encoder:

MLP head 分类 输出

训练

代码解析:请参考


模型流程:

输入:

1)原始图像:

        将图像按顺序分成指定的Patches,输入Linear Projection后进行Flatten操作。

2)位置编码:

 举个例子讲下transformer的输入输出细节及其他 - 知乎

(pytorch进阶之路)四种Position Embedding的原理及实现-CSDN博客

Learned Positional Embedding ,这个是绝对位置编码,即直接对不同的位置随机初始化一个postion embedding,这个postion embedding作为参数进行训练。(1D PE)

Sinusoidal Position Embedding ,相对位置编码,即三角函数编码。(2D PE)

 ViT使用1D位置编码得到position embedding,因为实验表明使用1DPE和2DPE的对性能影响不大。

import torch
import torch.nn as nn


def create_1d_learnable_embedding(pos_len, dim):
    pos_emb = nn.Embedding(pos_len, dim)
    # 初始化成全0
    nn.init.constant_(pos_emb.weight, 0)
    return pos_emb

 3) Class token

class token的embedding被随机初始化并与pos embedding相加,论文里面是class token是放在首位,也就是第0个位置. VIT中特殊class token的一些问题-CSDN博客

# 随机初始化
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# Classifier head
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# 具体forward过程
B = x.shape[0]
x = self.patch_embed(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1) 
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed

Transformer Encoder:

所有flatten之后的Patches与class token + PE执行stacked或者concatenated输入encoder中。

模型图片来源:霹雳吧啦Wz

【一站式梳理】ViT - Vision Transformer 流程+代码 学习记录_第2张图片

Multi-head Self-attention & 应用到图片_多头注意力机制的图像应用-CSDN博客

10.6. 自注意力和位置编码 — 动手学深度学习 2.0.0 documentation

MLP head 分类 输出

ViT 模型中只使用了 class token 的输出,将其送入 MLP 模块中,去输出最终的分类结果。class token的输出里包含了其他patches的综合编码信息。

训练

  1. 在较大的数据集上预训练;
  2. 在下游数据集上微调用于图像分类。

代码解析:请参考

【超详细】初学者包会的Vision Transformer(ViT)的PyTorch实现代码学习_vit_base_patch16_224_in21k模型-CSDN博客Vision Transformer(ViT)PyTorch代码全解析(附图解)_vit代码-CSDN博客

你可能感兴趣的:(人工智能,计算机视觉)