有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
Vision Transformer 源码解读1
Vision Transformer 源码解读2
Vision Transformer 源码解读3
Vision Transformer 源码解读4
class Block(nn.Module):
def __init__(self, config, vis):
super(Block, self).__init__()
self.hidden_size = config.hidden_size
self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config, vis)
class Attention(nn.Module):
def __init__(self, config, vis):
super(Attention, self).__init__()
self.vis = vis
self.num_attention_heads = config.transformer["num_heads"]
self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = Linear(config.hidden_size, self.all_head_size)
self.key = Linear(config.hidden_size, self.all_head_size)
self.value = Linear(config.hidden_size, self.all_head_size)
self.out = Linear(config.hidden_size, config.hidden_size)
self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.softmax = Softmax(dim=-1)
class VisionTransformer(nn.Module):
def forward(self, x, labels=None):
x, attn_weights = self.transformer(x)
# print(x.shape)
logits = self.head(x[:, 0])
# print(logits.shape)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
return loss
else:
return logits, attn_weights
前向传播函数打上断点,开启debug模式,查看数据维度变化:
输入x=[16,3,224,224],依次为batch_size,通道数,图像长和宽
经过self.transformer()后
x=[16,197,768],依次为batch_size,197=196+1其中196为序列长度、1为分类标记,自定义的向量维度
logits = [16,10],依次为batch_size,10分类的分数
class Transformer(nn.Module):
def __init__(self, config, img_size, vis):
super(Transformer, self).__init__()
self.embeddings = Embeddings(config, img_size=img_size)
self.encoder = Encoder(config, vis)
def forward(self, input_ids):
embedding_output = self.embeddings(input_ids)
encoded, attn_weights = self.encoder(embedding_output)
return encoded, attn_weights
前向传播函数打上断点,开启debug模式,查看数据维度变化:
Embeddings输入是彩色图,输出每个位置得到768维向量
Encoder输入与输出都是768维向量
class Embeddings(nn.Module):
def forward(self, x):
B = x.shape[0]
cls_tokens = self.cls_token.expand(B, -1, -1)
if self.hybrid:
x = self.hybrid_model(x)
x = self.patch_embeddings(x)
x = x.flatten(2)
x = x.transpose(-1, -2)
x = torch.cat((cls_tokens, x), dim=1)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
cls_token
:这是一个特殊的“分类(class)标记”,其设计灵感来源于BERT模型中的[CLS]
标记。在Vision Transformer中,这个cls_token
被添加到图像块(patch)嵌入的序列的前面,并且在整个Transformer模型的处理过程中一直携带着。模型的最终目标是使用这个cls_token
的表示(经过Transformer模型的多层处理后的输出)来进行分类任务。换句话说,cls_token
在模型的最后一层的输出被用作图像分类或其他下游任务的基础。这个就是197=196+1的1的由来patch_embeddings
:这是将输入图像分割成多个图像块(patches),然后将每个图像块转换成模型可以处理的嵌入向量的过程。在Vision Transformer中,输入图像首先被划分为多个固定大小的小块,每个小块接着通过一个卷积层(在这个代码中是Conv2d
层)转换成一个嵌入向量。这个卷积层的输出通道数等于模型的隐藏层大小(config.hidden_size
),这样每个图像块就被映射到了一个高维空间,以便后续由Transformer处理。patch_embeddings
实质上是对图像进行了一种“词嵌入”操作,将图像的原始像素值转换为模型可以理解的语义向量Vision Transformer 源码解读1
Vision Transformer 源码解读2
Vision Transformer 源码解读3
Vision Transformer 源码解读4