目录
ViT Paper: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE, ICLR 2021.
模型流程:
输入:
Transformer Encoder:
MLP head 分类 输出
训练
代码解析:请参考
1)原始图像:
将图像按顺序分成指定的Patches,输入Linear Projection后进行Flatten操作。
2)位置编码:
举个例子讲下transformer的输入输出细节及其他 - 知乎
(pytorch进阶之路)四种Position Embedding的原理及实现-CSDN博客
Learned Positional Embedding ,这个是绝对位置编码,即直接对不同的位置随机初始化一个postion embedding,这个postion embedding作为参数进行训练。(1D PE)
Sinusoidal Position Embedding ,相对位置编码,即三角函数编码。(2D PE)
ViT使用1D位置编码得到position embedding,因为实验表明使用1DPE和2DPE的对性能影响不大。
import torch
import torch.nn as nn
def create_1d_learnable_embedding(pos_len, dim):
pos_emb = nn.Embedding(pos_len, dim)
# 初始化成全0
nn.init.constant_(pos_emb.weight, 0)
return pos_emb
3) Class token
class token的embedding被随机初始化并与pos embedding相加,论文里面是class token是放在首位,也就是第0个位置. VIT中特殊class token的一些问题-CSDN博客
# 随机初始化
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# Classifier head
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# 具体forward过程
B = x.shape[0]
x = self.patch_embed(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
所有flatten之后的Patches与class token + PE执行stacked或者concatenated输入encoder中。
模型图片来源:霹雳吧啦Wz
Multi-head Self-attention & 应用到图片_多头注意力机制的图像应用-CSDN博客
10.6. 自注意力和位置编码 — 动手学深度学习 2.0.0 documentation
ViT 模型中只使用了 class token 的输出,将其送入 MLP 模块中,去输出最终的分类结果。class token的输出里包含了其他patches的综合编码信息。
【超详细】初学者包会的Vision Transformer(ViT)的PyTorch实现代码学习_vit_base_patch16_224_in21k模型-CSDN博客Vision Transformer(ViT)PyTorch代码全解析(附图解)_vit代码-CSDN博客