pytorch增加一维_VIT 三部曲 - 3 vit-pytorch

pytorch增加一维_VIT 三部曲 - 3 vit-pytorch_第1张图片

赵zhijian:VIT 三部曲

赵zhijian:VIT 三部曲 - 2 Vision-Transformer

赵zhijian:VIT 三部曲 - 3 vit-pytorch

模型和代码参考

https://github.com/likelyzhao/vit-pytorch

我们从代码中进行一些详细的分析:

class ViT(nn.Module): 

def __init__(self, *, image_size, patch_size, num_classes, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.): 

    super().__init__() 

    assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' 


    num_patches = (image_size // patch_size) ** 2 


    hidden_size = channels * patch_size ** 2 


    assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size' 


    self.patch_size = patch_size 


    self.hidden_size = hidden_size 


    self.embedding = nn.Conv2d(channels,hidden_size, patch_size, patch_size) 


    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, hidden_size)) 


    self.cls = nn.Parameter(torch.randn(1, 1, hidden_size)) 


    self.dropout = nn.Dropout(emb_dropout) 


    self.transformer = Encoder(hidden_size, depth, heads, mlp_dim, dropout_rate = dropout) 


    self.to_cls_token = nn.Identity() 


    self.mlp_head = nn.Linear(hidden_size, num_classes) 

首先介绍一些变量:

num_patches 表示切片出来的图片的块的数目,

hidden_size 表示切好的图片拉成一维向量后的特征长度,

pos_embedding 用来将self.hidden_size 长度的特征转换为 channels 维的特征,在这里把这样一个分块的全链接转换为了一个卷积来计算,还是非常的巧妙的

cls 是就是文章里面的class-token 用来做为position 0 的输入transformer 就是用来计算整体的特征提取,

mlp_head 是一个分类器用来进行最终的一个类别的分类.

下面我们介绍整个模型的推理过程:

def forward(self, img, mask = None): 

    x = self.embedding(img) 

    x = rearrange(x, 'b c h w -> b (h w) c') 

    b, n, _ = x.shape 

    cls_tokens = repeat(self.cls, '() n d -> b n d', b = b) 

    x = torch.cat((cls_tokens, x), dim=1) 

    x += self.pos_embedding[:, :(n + 1)] 

    x = self.dropout(x) 

    x = self.transformer(x) 

    x = self.to_cls_token(x[:, 0]) 

    return self.mlp_head(x) 

计算的主流程是把img 输入后先做一个卷积的操作, 然后根据需要reshape 到合理的尺寸,再在channel 维度加上一维cls 的token用来做分类任务, pos_embedding 是 论文中的position_embedding 用来在每个batch上添加学习到位置相关的特征.增加完position embedding 之后就是一个基于self attention的transformer 最后就是拿出整个特征的第一个维度的特征,然后最后跑一个mlp_head 得到最终的分类结果.

特别需要注意的一点是:

google 训练出来的模型的均值方差均为127.5 并不是pytorch 模型中常见的均值方差,因此使用pytorch example 中的 imagenet 测试需要修改pytorch 默认的normalize 成如下所示:

normalize_tf = transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]) 

实际测试中 Imagenet Top-1 会下降 5% 左右.

总结一下:

整体的VIT 模型并不复杂,经过实际的测试, 最小的VIT-B224-16的模型的推理速度大致与resnet50 相当, 但是实际测试的精度高达 83% 左右,大大优于经过各种骚气优化的resnet50 大概80% 左右,因此大家可以根据需要使用,相信可以获得很高的结果. VIT 三部曲已经完成,希望可以不吝点赞,有哪里写错的也欢迎留言讨论.

你可能感兴趣的:(pytorch增加一维)