# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torchsnooper as ts
import copy
import logging
import math
from os.path import join as pjoin
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
import models.configs as configs
logger = logging.getLogger(__name__)
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
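Besides the imports above, the code below also relies on np2th, ACT2FN and LabelSmoothing, which are defined elsewhere in the original file and are not reproduced in this excerpt. A minimal sketch of the first two, assuming the usual ViT-PyTorch definitions:
def np2th(weights, conv=False):
    """Convert a numpy array from the checkpoint into a torch tensor (HWIO -> OIHW for conv kernels)."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)

def swish(x):
    return x * torch.sigmoid(x)

ACT2FN = {"gelu": F.gelu, "relu": F.relu, "swish": swish}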
To make it easy to inspect this section's code directly, a call is placed right after it: an input tensor is created and a forward pass produces the output.
if __name__=="__main__":
config = CONFIGS['ViT-B_16']
config.split = 'overlap'
net = VisionTransformer(config,num_classes=200)
x = torch.rand((2, 3, 224, 224)).type(torch.float32)
y = net(x)
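Since no labels are passed here, the forward pass returns only the classification logits, so y should come out with shape (2, 200). A quick check one could append (a sketch, not part of the original file):
print(y.shape)  # expected: torch.Size([2, 200])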
There are multiple attention heads here. Some articles stress that multiple heads project the embedding vectors into different subspaces, i.e. each head can attend to different kinds of information. Neat; I wonder whether one of them can attend to pose?
class Attention(nn.Module):
def __init__(self, config):
super(Attention, self).__init__()
self.num_attention_heads = config.transformer["num_heads"]
self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = Linear(config.hidden_size, self.all_head_size)
self.key = Linear(config.hidden_size, self.all_head_size)
self.value = Linear(config.hidden_size, self.all_head_size)
self.out = Linear(config.hidden_size, config.hidden_size)
self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.softmax = Softmax(dim=-1)
def transpose_for_scores(self, x):
with ts.snoop():
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)#(2,325,12,64)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)#(2,12,325,64)
def forward(self, hidden_states):
with ts.snoop():  # torchsnooper.snoop() prints every line's result as it runs, which helps with debugging
# hidden_states is (2, 325, 768); 325 = 324 patch tokens + 1 cls token, i.e. the embedding vectors
# self.query/key/value are separate linear projections, each with output (2, 325, 768); 768 = 12 * 64, i.e. 12 attention heads of 64 dimensions each
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
# (2, 12, 325, 64): 12 is the number of attention heads, 325 the number of embedding vectors. Attention is computed between the embeddings within each head,
# and the 12 head outputs are merged at the end, so transpose_for_scores reshapes the tensors into this per-head layout here
query_layer = self.transpose_for_scores(mixed_query_layer)#(2,12,325,64)
key_layer = self.transpose_for_scores(mixed_key_layer)#(2,12,325,64)
value_layer = self.transpose_for_scores(mixed_value_layer)#(2,12,325,64)
# compute the similarity q @ k^T; in the last two dimensions, each row holds the similarity of every embedding to the current one
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))#(2,12,325,325)
attention_scores = attention_scores / math.sqrt(self.attention_head_size)  # q @ k^T / sqrt(d)
attention_probs = self.softmax(attention_scores)
weights = attention_probs#(2,12,325,325)
attention_probs = self.attn_dropout(attention_probs)
# distribute the weights over the value vectors: softmax(q @ k^T / sqrt(d)) @ v
context_layer = torch.matmul(attention_probs, value_layer)#(2,12,325,64)
# merge all attention heads back together
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() #(2,325,12,64)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)#(2,325,768)
attention_output = self.out(context_layer) #(2,325,768)
attention_output = self.proj_dropout(attention_output)
return attention_output, weights
# Returns: the embeddings after multi-head attention (weights applied, heads merged, projected back to the embedding dimension), and the attention weights, i.e. the similarity relations between vectors for each of the 12 heads
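To make the shape bookkeeping concrete, here is a standalone sketch of the per-head reshape and score computation with the toy sizes from the comments (self-attention of a single projection, hypothetical tensors, not part of the original file):
import math
import torch
B, N, H, D = 2, 325, 12, 64                       # batch, tokens, heads, head dim
q = torch.randn(B, N, H * D)                      # (2, 325, 768), e.g. mixed_query_layer
q = q.view(B, N, H, D).permute(0, 2, 1, 3)        # (2, 12, 325, 64), as in transpose_for_scores
scores = q @ q.transpose(-1, -2) / math.sqrt(D)   # (2, 12, 325, 325): one score matrix per head
probs = scores.softmax(dim=-1)
out = (probs @ q).permute(0, 2, 1, 3).reshape(B, N, H * D)  # back to (2, 325, 768)
print(scores.shape, out.shape)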
Creation and calling code (inside Block):
self.attn = Attention(config)
------
x, weights = self.attn(x)
# The Mlp below takes the attention output and applies a further (position-wise feed-forward) mapping
class Mlp(nn.Module):
def __init__(self, config):
super(Mlp, self).__init__()
self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
self.act_fn = ACT2FN["gelu"]
self.dropout = Dropout(config.transformer["dropout_rate"])
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.normal_(self.fc1.bias, std=1e-6)
nn.init.normal_(self.fc2.bias, std=1e-6)
def forward(self, x):
with ts.snoop():
x = self.fc1(x)
x = self.act_fn(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
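In the standard ViT-B_16 configuration, hidden_size is 768 and mlp_dim is 3072, so this block expands 768 -> 3072 -> 768 with a GELU in between. A quick shape check (a sketch, assuming CONFIGS defined further down and the ACT2FN sketch above):
mlp = Mlp(CONFIGS['ViT-B_16'])
print(mlp(torch.randn(2, 325, 768)).shape)   # torch.Size([2, 325, 768])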
class Embeddings(nn.Module):
"""Construct the embeddings from patch, position embeddings.
"""
def __init__(self, config, img_size, in_channels=3):
super(Embeddings, self).__init__()
self.hybrid = None  # hybrid model, presumably ResNet + ViT? Not used here
img_size = _pair(img_size)
patch_size = _pair(config.patches["size"])
if config.split == 'non-overlap':
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])  # non-overlapping patches; content that straddles a patch boundary cannot be captured as a whole
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
stride=patch_size)  # kernel size and stride both equal patch_size
elif config.split == 'overlap':
n_patches = ((img_size[0] - patch_size[0]) // config.slide_step + 1) * ((img_size[1] - patch_size[1]) // config.slide_step + 1)  # overlapping patches: same kernel size, smaller stride, so the number of output patches increases
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
stride=(config.slide_step, config.slide_step))
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))  # positional information
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))  # the extra classification token prepended to the patch embeddings
self.dropout = Dropout(config.transformer["dropout_rate"])
def forward(self, x):
with ts.snoop():
B = x.shape[0]
cls_tokens = self.cls_token.expand(B, -1, -1)#(2,1,768)
if self.hybrid:
x = self.hybrid_model(x)#(2, 768, 18, 18)
x = self.patch_embeddings(x)
x = x.flatten(2)#(2,768,324)
x = x.transpose(-1, -2)#(2,324,768)
x = torch.cat((cls_tokens, x), dim=1)#(2,325,768)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings#(2,325,768)
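With split='overlap' the patch grid becomes denser than the non-overlapping 14x14 grid. A worked example of the two n_patches formulas above (slide_step = 12 is an assumption, chosen because it reproduces the 324/325 shapes in the comments):
img, patch, step = 224, 16, 12
n_overlap = ((img - patch) // step + 1) ** 2   # 18 * 18 = 324 patches
n_nonoverlap = (img // patch) ** 2             # 14 * 14 = 196 patches
print(n_overlap + 1, n_nonoverlap + 1)         # +1 for the cls token -> 325 and 197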
class Block(nn.Module):
def __init__(self, config):
super(Block, self).__init__()
self.hidden_size = config.hidden_size
self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config)
def forward(self, x):
with ts.snoop():
h = x
x = self.attention_norm(x)
x, weights = self.attn(x)
x = x + h  # pre-norm: LayerNorm -> attention -> residual add
h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h  # LayerNorm -> MLP -> residual add
return x, weights  # weights are this block's attention maps: for each of the 12 heads, the softmax-normalised similarity matrix between embeddings
def load_from(self, weights, n_block):  # loads pretrained (JAX/Flax-style) weights into this block; not walked through in detail here
ROOT = f"Transformer/encoderblock_{n_block}"
with torch.no_grad():
query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
self.attn.query.weight.copy_(query_weight)
self.attn.key.weight.copy_(key_weight)
self.attn.value.weight.copy_(value_weight)
self.attn.out.weight.copy_(out_weight)
self.attn.query.bias.copy_(query_bias)
self.attn.key.bias.copy_(key_bias)
self.attn.value.bias.copy_(value_bias)
self.attn.out.bias.copy_(out_bias)
mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
self.ffn.fc1.weight.copy_(mlp_weight_0)
self.ffn.fc2.weight.copy_(mlp_weight_1)
self.ffn.fc1.bias.copy_(mlp_bias_0)
self.ffn.fc2.bias.copy_(mlp_bias_1)
self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
class Part_Attention(nn.Module):
def __init__(self):
super(Part_Attention, self).__init__()
def forward(self, x):
with ts.snoop():
length = len(x)
last_map = x[0]  # (2, 12, 325, 325): the attention weights of the first block
for i in range(1, length):
last_map = torch.matmul(x[i], last_map)
last_map = last_map[:,:,0,1:]  # (2, 12, 324): for each head, keep only the cls-token row, i.e. its similarity to every patch token
max_val, max_inx = last_map.max(2)  # per head, the value and index of the patch most attended to by the cls token
return max_val, max_inx
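Multiplying the per-layer attention maps chains attention across layers, which looks essentially like the attention-rollout idea: the product approximates how strongly each patch token ultimately feeds into the cls token. A toy sketch with hypothetical sizes:
import torch
layers, B, H, N = 3, 2, 12, 5                 # toy sizes: layers, batch, heads, tokens
attn = [torch.softmax(torch.randn(B, H, N, N), dim=-1) for _ in range(layers)]
rollout = attn[0]
for a in attn[1:]:
    rollout = a @ rollout                     # same recursion as last_map above
cls_to_patches = rollout[:, :, 0, 1:]         # cls-token row per head, excluding the cls column itself
print(cls_to_patches.argmax(dim=2).shape)     # (B, H): one selected patch index per head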
class Encoder(nn.Module):
def __init__(self, config):
super(Encoder, self).__init__()
self.layer = nn.ModuleList()
for _ in range(config.transformer["num_layers"] - 1):
layer = Block(config)
self.layer.append(copy.deepcopy(layer))
self.part_select = Part_Attention()
self.part_layer = Block(config)
self.part_norm = LayerNorm(config.hidden_size, eps=1e-6)
def forward(self, hidden_states):#(2,325,768)
with ts.snoop():
attn_weights = []
for layer in self.layer:  # run the blocks and record the attention weights
hidden_states, weights = layer(hidden_states)
attn_weights.append(weights)
part_num, part_inx = self.part_select(attn_weights)
# part_num (2, 12) and part_inx (2, 12): the maximum value and its index within each of the 12 attention heads
# Intuition: each head works in a different subspace and captures different information, so per subspace we take the token most similar to the cls token
part_inx = part_inx + 1  # shift by one: hidden_states[0] is the cls token, while part_inx indexes the 324 patch similarities
parts = []
B, num = part_inx.shape
for i in range(B):
parts.append(hidden_states[i, part_inx[i,:]])  # gather the embedding vectors judged most discriminative
parts = torch.stack(parts).squeeze(1)#(2,12,768)
concat = torch.cat((hidden_states[:,0].unsqueeze(1), parts), dim=1)  # (2, 13, 768): combine with the cls token
part_states, part_weights = self.part_layer(concat)  # (2, 13, 768) and (2, 12, 13, 13): fed through the final block. How exactly this captures the discriminative information is still not entirely clear to me; if you know, please share in the comments
part_encoded = self.part_norm(part_states)
return part_encoded
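For ViT-B_16 (num_layers = 12) the Encoder therefore stacks 11 standard blocks, the part-selection step, and one final part_layer block. A minimal standalone check (a sketch, reusing the classes and CONFIGS defined in this file):
enc = Encoder(CONFIGS['ViT-B_16'])
dummy = torch.randn(2, 325, 768)     # same shape as the Embeddings output
print(enc(dummy).shape)              # torch.Size([2, 13, 768]): cls token + 12 selected part tokens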
class Transformer(nn.Module):
def __init__(self, config, img_size):
super(Transformer, self).__init__()
self.embeddings = Embeddings(config, img_size=img_size)
self.encoder = Encoder(config)
def forward(self, input_ids):
with ts.snoop():
embedding_output = self.embeddings(input_ids)
part_encoded = self.encoder(embedding_output)
return part_encoded
class VisionTransformer(nn.Module):
def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False):
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.smoothing_value = smoothing_value
self.zero_head = zero_head
self.classifier = config.classifier
self.transformer = Transformer(config, img_size)
self.part_head = Linear(config.hidden_size, num_classes)
def forward(self, x, labels=None):
with ts.snoop():
part_tokens = self.transformer(x)
part_logits = self.part_head(part_tokens[:, 0])  # the class token is used as the classification output
if labels is not None:#train
if self.smoothing_value == 0:
loss_fct = CrossEntropyLoss()
else:
loss_fct = LabelSmoothing(self.smoothing_value)  # label smoothing instead of plain cross-entropy
part_loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1))
loss = part_loss + contrast_loss
return loss, part_logits
else: #valid
return part_logits
def load_from(self, weights):
with torch.no_grad():
self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
posemb_new = self.transformer.embeddings.position_embeddings
if posemb.size() == posemb_new.size():
self.transformer.embeddings.position_embeddings.copy_(posemb)
else:
logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
ntok_new = posemb_new.size(1)
if self.classifier == "token":
posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
ntok_new -= 1
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(np.sqrt(len(posemb_grid)))
gs_new = int(np.sqrt(ntok_new))
print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
for bname, block in self.transformer.encoder.named_children():
if not bname.startswith('part'):
for uname, unit in block.named_children():
unit.load_from(weights, n_block=uname)
if self.transformer.embeddings.hybrid:
self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
gn_weight = np2th(weights["gn_root/scale"]).view(-1)
gn_bias = np2th(weights["gn_root/bias"]).view(-1)
self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
for uname, unit in block.named_children():
unit.load_from(weights, n_block=bname, n_unit=uname)
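load_from expects a dict-like object mapping the Flax parameter names above (e.g. Transformer/encoderblock_0/...) to numpy arrays, typically obtained by np.load on a pretrained ViT checkpoint. A usage sketch (the file path is hypothetical):
import numpy as np
pretrained = np.load("ViT-B_16.npz")   # Flax parameter name -> numpy array
net.load_from(pretrained)              # copies the arrays into the torch modules defined above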
def con_loss(features, labels):
B, _ = features.shape
features = F.normalize(features)
cos_matrix = features.mm(features.t())
pos_label_matrix = torch.stack([labels == labels[i] for i in range(B)]).float()
neg_label_matrix = 1 - pos_label_matrix
pos_cos_matrix = 1 - cos_matrix
neg_cos_matrix = cos_matrix - 0.4
neg_cos_matrix[neg_cos_matrix < 0] = 0
loss = (pos_cos_matrix * pos_label_matrix).sum() + (neg_cos_matrix * neg_label_matrix).sum()
loss /= B
return loss
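con_loss pulls same-class features toward cosine similarity 1 and pushes different-class pairs apart only while their similarity exceeds the 0.4 margin. A tiny worked example with hypothetical unit-norm features:
feats = torch.tensor([[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]])   # two samples of class 0, one of class 1
labels = torch.tensor([0, 0, 1])
print(con_loss(feats, labels))   # ~0.27: positive pairs penalised by 1 - cos, negatives by max(cos - 0.4, 0)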
CONFIGS = {
'ViT-B_16': configs.get_b16_config(),
'ViT-B_32': configs.get_b32_config(),
'ViT-L_16': configs.get_l16_config(),
'ViT-L_32': configs.get_l32_config(),
'ViT-H_14': configs.get_h14_config(),
'testing': configs.get_testing(),
}
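For reference, the ViT-B_16 entry used in the example call resolves to the standard ViT-B/16 hyper-parameters; the values below are a sketch based on the usual configs.py (not shown in this excerpt), so exact fields may differ:
# config.hidden_size  = 768
# config.patches.size = (16, 16)
# config.transformer  = {"num_layers": 12, "num_heads": 12, "mlp_dim": 3072,
#                        "dropout_rate": 0.1, "attention_dropout_rate": 0.0}
# config.classifier   = "token"
# config.slide_step   = 12   # only used when config.split == 'overlap'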
This paper builds on the ViT architecture and mainly introduces two things: extracting discriminative features and a contrastive loss; the latter needs no further discussion. The discriminative-feature part is still not entirely clear to me: all of the earlier q @ k^T attention matrices are multiplied together, the row of similarities belonging to the cls_token is picked out, the maximum in each attention head is taken, the corresponding feature vectors are selected, combined with the cls_token, and fed into the final layer.
First, I do not understand why the matrices are multiplied, or what the product achieves; is it simply so that all of the layers' weights get involved?
Second, why pick out the embedding vectors most similar to the cls_token, combine them with the cls_token, and feed them into the final layer?