Feel free to leave any questions in the comments below.
All the code in this article was run in PyCharm.
The companion code for this article has been uploaded.
Vision Transformer Source Code Walkthrough 1
Vision Transformer Source Code Walkthrough 2
Vision Transformer Source Code Walkthrough 3
Vision Transformer Source Code Walkthrough 4
# imports used by the snippets in this article
import math
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def forward(self, hidden_states):
        # hidden_states: [batch, num_tokens, hidden_size], e.g. [16, 197, 768]
        attn_weights = []
        for layer_block in self.layer:  # pass through each stacked Transformer Block
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:  # collect per-layer attention maps for visualization
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)  # final LayerNorm
        return encoded, attn_weights
Observed shapes (batch size 16; 197 tokens = 196 image patches + 1 class token; hidden size 768):
hidden_states.shape = torch.Size([16, 197, 768])
encoded.shape = torch.Size([16, 197, 768])
This code implements the encoder: it processes the token sequence and can optionally return each layer's attention weights. By stacking Blocks and applying a final normalization, it learns an effective feature representation of the input.
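The forward pass above relies on modules created in __init__, which the excerpt omits. A minimal constructor sketch that would support it, assuming the signature hidden_size/num_layers/vis (these argument names are assumptions chosen to match the shapes above, not necessarily the original code) and the Block class defined next:

class Encoder(nn.Module):
    def __init__(self, hidden_size=768, num_layers=12, vis=False):
        super().__init__()
        self.vis = vis
        # num_layers stacked Transformer blocks (Block is defined below)
        self.layer = nn.ModuleList([Block(hidden_size, vis) for _ in range(num_layers)])
        # final LayerNorm applied after the last block
        self.encoder_norm = nn.LayerNorm(hidden_size, eps=1e-6)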
class Block(nn.Module):
    def forward(self, x):
        h = x                       # save input for the residual connection
        x = self.attention_norm(x)  # pre-norm: LayerNorm before attention
        x, weights = self.attn(x)   # multi-head self-attention
        x = x + h                   # residual connection around attention
        h = x
        x = self.ffn_norm(x)        # pre-norm: LayerNorm before the MLP
        x = self.ffn(x)             # position-wise feed-forward network (Mlp)
        x = x + h                   # residual connection around the MLP
        return x, weights
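Note the pre-norm layout: LayerNorm is applied before each sub-layer rather than after it, which is the arrangement ViT uses and tends to stabilize training of deep Transformers. A matching constructor sketch (the parameter names and eps value are assumptions consistent with the snippets above):

class Block(nn.Module):
    def __init__(self, hidden_size=768, vis=False):
        super().__init__()
        self.attention_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=12, vis=vis)  # defined later in this article
        self.ffn = Mlp(hidden_size, mlp_dim=3072)                  # defined next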
class Mlp(nn.Module):
    def forward(self, x):
        x = self.fc1(x)     # expand: hidden_size -> mlp_dim (e.g. 768 -> 3072)
        x = self.act_fn(x)  # non-linearity (GELU in ViT)
        x = self.dropout(x)
        x = self.fc2(x)     # project back: mlp_dim -> hidden_size
        x = self.dropout(x)
        return x
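A constructor sketch for Mlp, assuming the standard ViT-Base expansion ratio of 4 (mlp_dim = 3072 when hidden_size = 768; the dropout rate is an assumption):

class Mlp(nn.Module):
    def __init__(self, hidden_size=768, mlp_dim=3072, dropout_rate=0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, hidden_size)
        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)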
The most important class: Attention
class Attention(nn.Module):
    def transpose_for_scores(self, x):
        # split the last dim into (num_heads, head_size): [B, 197, 768] -> [B, 197, 12, 64]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        # move the head dim forward: [B, 197, 12, 64] -> [B, 12, 197, 64]
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # project to Q, K, V; each is Linear(in_features=768, out_features=768, bias=True)
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        # reshape to per-head layout: [B, 12, 197, 64]
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # scaled dot-product attention: Q @ K^T / sqrt(d_k), shape [B, 12, 197, 197]
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None  # keep maps for visualization
        attention_probs = self.attn_dropout(attention_probs)
        # weighted sum of values: [B, 12, 197, 64]
        context_layer = torch.matmul(attention_probs, value_layer)
        # merge the heads back: [B, 12, 197, 64] -> [B, 197, 12, 64] -> [B, 197, 768]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)  # output projection, 768 -> 768
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
Shapes observed inside transpose_for_scores:
x after view ------ torch.Size([16, 197, 12, 64])
new_x_shape ------ torch.Size([16, 197, 12, 64])
x.permute(0, 2, 1, 3) ------ torch.Size([16, 12, 197, 64])
The transpose_for_scores function splits the last dimension of the input tensor into num_attention_heads (the number of attention heads) and attention_head_size (the size of each head), then permutes the dimensions so the subsequent matrix multiplications operate on each head independently.
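To round this out, here is a constructor sketch that, combined with the forward method above, gives a runnable module, followed by a quick shape check reproducing the tensors discussed in this article (the signature and dropout rate are assumptions; head_size = 768 / 12 = 64):

class Attention(nn.Module):
    def __init__(self, hidden_size=768, num_heads=12, vis=False, dropout_rate=0.0):
        super().__init__()
        self.vis = vis
        self.num_attention_heads = num_heads
        self.attention_head_size = hidden_size // num_heads  # 64
        self.all_head_size = self.num_attention_heads * self.attention_head_size  # 768
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.attn_dropout = nn.Dropout(dropout_rate)
        self.proj_dropout = nn.Dropout(dropout_rate)
        self.softmax = nn.Softmax(dim=-1)

# quick shape check
x = torch.randn(16, 197, 768)
out, w = Attention(vis=True)(x)
print(out.shape)  # torch.Size([16, 197, 768])
print(w.shape)    # torch.Size([16, 12, 197, 197])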