Paper Reading: MobileViT

Paper: MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
Code: PyTorch implementation
Affiliation: Apple
Reference: https://mp.weixin.qq.com/s/HggXYrhMLqVqnoIcsCOBGg

Abstract:

Light-weight convolutional neural networks (CNNs) are the de-facto choice for mobile vision tasks. Their spatial inductive biases allow them to learn representations with fewer parameters across different vision tasks. However, these networks are spatially local. To learn global representations, self-attention-based vision transformers (ViTs) have been adopted. Unlike CNNs, ViTs are heavy-weight. In this paper, we ask the following question: is it possible to combine the strengths of CNNs and ViTs to build a light-weight, low-latency network for mobile vision tasks?

To this end, we introduce MobileViT, a light-weight and general-purpose vision transformer for mobile devices. MobileViT presents a different perspective for global information processing with transformers, namely transformers as convolutions.

Our results show that MobileViT significantly outperforms both CNN- and ViT-based networks across different tasks and datasets. On ImageNet-1k, MobileViT achieves a top-1 accuracy of 78.4% with about 6 million parameters, which is 3.2% and 6.2% higher than MobileNetv3 and DeiT, respectively, for a similar parameter count. On the MS-COCO object detection task, MobileViT is 5.7% more accurate than MobileNetv3 for a similar number of parameters.
In short:
The abstract argues that CNNs capture only local features, while transformers capture global features, which is why ViT was proposed; but transformers lack spatial inductive biases, so ViTs need far more parameters (and data) to learn good representations, making them heavy-weight.
MobileViT is therefore proposed to combine the spatial inductive bias of CNNs with the global-representation strength of ViTs.

Introduction (abridged)

This part of the paper surveys earlier attempts to combine CNNs and transformers, and states the goal of this work.
Because FLOP counts ignore memory access, degree of parallelism, and platform characteristics, FLOPs alone are a poor proxy for on-device latency; the focus of this paper is therefore not on optimizing FLOPs but on designing a light-weight, general-purpose, and low-latency network for mobile vision tasks.
CNN strengths: spatial inductive biases, low sensitivity to data augmentation.
ViT strengths: input-adaptive weighting, global processing.
MobileViT therefore encodes both local and global information in a tensor effectively (Figure 1b of the paper). Unlike ViT and its variants (with and without convolutions), MobileViT learns global representations from a different perspective: a standard convolution can be decomposed into three operations, namely unfold, local (matrix-multiplication) processing, and fold, and MobileViT replaces the local processing with global processing by a transformer. This lets it learn better representations with fewer parameters and a simple training recipe (e.g., basic augmentation). A minimal sketch of this unfold / global processing / fold idea is given right below.
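
To make the "transformers as convolutions" idea concrete, here is a minimal sketch (the shapes and einops patterns are illustrative and mirror the implementation further down, not the paper's exact code) of the unfold / global processing / fold sequence:

import torch
from einops import rearrange

x = torch.randn(2, 64, 32, 32)   # feature map after the local (convolutional) step
ph, pw = 2, 2                    # patch height/width; H and W must be divisible by them

# Unfold: split the map into non-overlapping ph x pw patches. Pixels that sit at the same
# position inside their patch form one sequence whose length is the number of patches.
seq = rearrange(x, 'b d (nh ph) (nw pw) -> b (ph pw) (nh nw) d', ph=ph, pw=pw)

# Global processing: MobileViT applies a transformer that maps (..., n, d) -> (..., n, d)
# to each of the ph*pw sequences, e.g. seq = transformer(seq); omitted in this sketch.

# Fold: put every pixel back at its original spatial location.
y = rearrange(seq, 'b (ph pw) (nh nw) d -> b d (nh ph) (nw pw)',
              ph=ph, pw=pw, nh=x.shape[2] // ph, nw=x.shape[3] // pw)
print(y.shape)  # torch.Size([2, 64, 32, 32])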

Related Work

Summarized from https://mp.weixin.qq.com/s/HggXYrhMLqVqnoIcsCOBGg
Vision transformers were applied to large-scale image recognition, and the results showed that with extremely large datasets (e.g., JFT-300M), ViTs can reach CNN-level accuracy without image-specific inductive biases.

With extensive data augmentation, heavy L2 regularization, and distillation, ViTs can also be trained on the ImageNet dataset to achieve CNN-level performance. However, unlike CNNs, ViTs show substandard optimizability and are difficult to train.

Subsequent works showed that this poor optimizability is due to ViT's lack of spatial inductive biases. Incorporating such biases into ViTs using convolutions improves their stability and performance. Different designs have been explored to get the benefits of both convolutions and transformers.

For example, ViT-C adds an early convolutional stem before ViT.

CvT modifies the multi-head attention in the transformer and uses depth-wise separable convolutions instead of linear projections.

BoTNet replaces the standard convolution in the ResNet bottleneck unit with multi-head attention.

ConViT incorporates a soft convolutional inductive bias via gated positional self-attention.

PiT extends ViT with depth-wise-convolution-based pooling layers.

Although these models can match CNN performance with data augmentation, most of them are heavy-weight. For example, PiT and CvT learn 6.1x and 1.7x more parameters than EfficientNet, respectively, while achieving similar performance on ImageNet-1k (top-1 accuracy of about 81.6%).

Moreover, when these models are scaled down to build light-weight ViT models, their performance is noticeably worse than that of light-weight CNNs. For a parameter budget of about 6 million, PiT's ImageNet-1k accuracy is 2.2% lower than MobileNetv3's.
So: how can the strengths of convolutions and transformers be combined to build light-weight networks for mobile vision tasks?

Let's first look at the results MobileViT delivers. (Figure from the paper; image not reproduced here.)

Core algorithm of the paper:

(Figure: the MobileViT block from the paper; image not reproduced here.) In the block, an n x n convolution followed by a 1 x 1 convolution produces local representations; the feature map is then unfolded into patches and processed by a stack of transformer layers for global representations, folded back, projected to the original channel count, concatenated with the block input, and fused by a final n x n convolution. The code below mirrors this structure.
Code walkthrough:
Let's first look at the ViT code:
Reference: https://zhuanlan.zhihu.com/p/361686988

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # layer normalization
        self.fn = fn  # the wrapped module (attention or feed-forward)
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # MLP: Linear -> GELU -> Dropout -> Linear -> Dropout
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    # attention
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads  # total width of all heads; input size of the final linear projection
        project_out = not (heads == 1 and dim_head == dim)  # skip the output projection only for a single head whose width equals dim

        self.heads = heads  # number of attention heads
        self.scale = dim_head ** -0.5  # scaling factor 1/sqrt(d_k), as in "Attention Is All You Need"

        self.attend = nn.Softmax(dim = -1)  # softmax over the key dimension
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # a single linear layer producing Q, K and V

        # output projection back to dim; identity when project_out is False
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads  # unpack the input shape (b, n, dim) and the number of heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)  # joint linear projection, then split into Q, K, V
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)  # reshape each of Q, K, V to (b, heads, n, dim_head)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale  # dot products of queries and keys, scaled by 1/sqrt(d_k)

        attn = self.attend(dots)  # softmax to obtain the attention weights

        out = einsum('b h i j, b h j d -> b h i d', attn, v)  # weight the values by the attention weights
        out = rearrange(out, 'b h n d -> b n (h d)')  # merge the heads back: (b, n, heads * dim_head)
        return self.to_out(out)  # final linear projection (or identity)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])  # the Transformer is a stack of encoder layers
        for _ in range(depth):
            # each encoder layer has two sub-blocks: self-attention and feed-forward
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),  # multi-head self-attention block
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))  # feed-forward block
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            # both sub-blocks use residual connections
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'  # the image must split into whole patches
        num_patches = (image_size // patch_size) ** 2  # number of patches
        patch_dim = channels * patch_size ** 2  # flattened patch size (patch height * patch width * channels), input size of the patch embedding
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'  # pooling must be 'cls' or 'mean'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),  # split each image (c channels, height h*p1, width w*p2) into h*w patches,
                                                                                                    # each flattened into a vector of length p1*p2*c
                                                                                                    # e.g. with patch_size 16: (8, 3, 48, 48) -> (8, 9, 768)
            nn.Linear(patch_dim, dim),  # patch embedding: project each flattened patch (patch_dim values) to dim
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # learnable position embedding, initialized from a normal distribution
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # learnable classification (cls) token
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)  # Transformer模块

        self.pool = pool
        self.to_latent = nn.Identity()  # placeholder (identity)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),  # layer normalization
            nn.Linear(dim, num_classes)  # linear classifier head
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)  # patch embedding, shape (b, n, dim): b batch size, n number of patches, dim embedding width
        b, n, _ = x.shape  # shape (b, n, dim)

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)  # broadcast the cls token from (1, 1, dim) to (b, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)  # prepend the cls token; x now has shape (b, n+1, dim)
        x += self.pos_embedding[:, :(n + 1)]  # add the position embeddings; shape (b, n+1, dim)
        x = self.dropout(x)

        x = self.transformer(x)  # transformer encoder

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)  # classification head
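
A quick usage sketch (the hyper-parameters below are illustrative, not tied to the paper): build a ViT and push a dummy batch through it.

v = ViT(image_size=224, patch_size=16, num_classes=1000, dim=1024,
        depth=6, heads=16, mlp_dim=2048, dropout=0.1, emb_dropout=0.1)
img = torch.randn(2, 3, 224, 224)
preds = v(img)
print(preds.shape)  # torch.Size([2, 1000])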

Having read the ViT code, now look at the MobileViT implementation below. (Figure: MobileViT structure from the paper; image not reproduced here.)

from torch import nn
import torch
from einops import rearrange



def conv_bn(inp,oup,kernel_size=3,stride=1):
    return nn.Sequential(
        nn.Conv2d(inp,oup,kernel_size=kernel_size,stride=stride,padding=kernel_size//2),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

class PreNorm(nn.Module):
    def __init__(self,dim,fn):
        super().__init__()
        self.ln=nn.LayerNorm(dim)
        self.fn=fn
    def forward(self,x,**kwargs):
        return self.fn(self.ln(x),**kwargs)

class FeedForward(nn.Module):
    def __init__(self,dim,mlp_dim,dropout) :
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(dim,mlp_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim,dim),
            nn.Dropout(dropout)
        )
    def forward(self,x):
        return self.net(x)
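
# Same multi-head self-attention as in the ViT code above, except that the tensors here
# carry an extra axis for the pixel position within each patch: (batch, ph*pw, num_patches, dim).
# Attention is therefore computed, for each pixel position, across all patches.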

class Attention(nn.Module):
    def __init__(self,dim,heads,head_dim,dropout):
        super().__init__()
        inner_dim=heads*head_dim
        project_out=not(heads==1 and head_dim==dim)

        self.heads=heads
        self.scale=head_dim**-0.5

        self.attend=nn.Softmax(dim=-1)
        self.to_qkv=nn.Linear(dim,inner_dim*3,bias=False)
        
        self.to_out=nn.Sequential(
            nn.Linear(inner_dim,dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self,x):
        qkv=self.to_qkv(x).chunk(3,dim=-1)
        q,k,v=map(lambda t:rearrange(t,'b p n (h d) -> b p h n d',h=self.heads),qkv)
        dots=torch.matmul(q,k.transpose(-1,-2))*self.scale
        attn=self.attend(dots)
        out=torch.matmul(attn,v)
        out=rearrange(out,'b p h n d -> b p n (h d)')
        return self.to_out(out)





class Transformer(nn.Module):
    def __init__(self,dim,depth,heads,head_dim,mlp_dim,dropout=0.):
        super().__init__()
        self.layers=nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim,Attention(dim,heads,head_dim,dropout)),
                PreNorm(dim,FeedForward(dim,mlp_dim,dropout))
            ]))


    def forward(self,x):
        out=x
        for att,ffn in self.layers:
            out=out+att(out)
            out=out+ffn(out)
        return out
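
# MobileViT block: an n x n conv followed by a 1 x 1 conv builds local representations;
# the feature map is then unfolded into patches and run through a Transformer for global
# representations, folded back, projected to the input channel count with a 1 x 1 conv,
# concatenated with the block input, and fused by a final n x n conv.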

class MobileViTAttention(nn.Module):
    def __init__(self,in_channel=3,dim=512,kernel_size=3,patch_size=7,depth=3,mlp_dim=1024):
        super().__init__()
        self.ph,self.pw=patch_size,patch_size
        self.conv1=nn.Conv2d(in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)
        self.conv2=nn.Conv2d(in_channel,dim,kernel_size=1)

        self.trans=Transformer(dim=dim,depth=depth,heads=8,head_dim=64,mlp_dim=mlp_dim)

        self.conv3=nn.Conv2d(dim,in_channel,kernel_size=1)
        self.conv4=nn.Conv2d(2*in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)

    def forward(self,x):
        y=x.clone() #bs,c,h,w

        ## Local Representation
        y=self.conv2(self.conv1(x)) #bs,dim,h,w

        ## Global Representation
        _,_,h,w=y.shape
        y=rearrange(y,'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim',ph=self.ph,pw=self.pw) #bs,ph*pw,nh*nw,dim
        y=self.trans(y)
        y=rearrange(y,'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)',ph=self.ph,pw=self.pw,nh=h//self.ph,nw=w//self.pw) #bs,dim,h,w

        ## Fusion
        y=self.conv3(y) #bs,c,h,w
        y=torch.cat([x,y],1) #bs,2*c,h,w
        y=self.conv4(y) #bs,c,h,w

        return y
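
# MV2Block: MobileNetV2 inverted residual block (1 x 1 expansion -> 3 x 3 depth-wise conv
# -> 1 x 1 linear projection), with a residual connection when stride == 1 and the input
# and output channel counts match.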


class MV2Block(nn.Module):
    def __init__(self,inp,out,stride=1,expansion=4):
        super().__init__()
        self.stride=stride
        hidden_dim=inp*expansion
        self.use_res_connection=stride==1 and inp==out

        if expansion==1:
            self.conv=nn.Sequential(
                nn.Conv2d(hidden_dim,hidden_dim,kernel_size=3,stride=self.stride,padding=1,groups=hidden_dim,bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                nn.Conv2d(hidden_dim,out,kernel_size=1,stride=1,bias=False),
                nn.BatchNorm2d(out)
            )
        else:
            self.conv=nn.Sequential(
                nn.Conv2d(inp,hidden_dim,kernel_size=1,stride=1,bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                nn.Conv2d(hidden_dim,hidden_dim,kernel_size=3,stride=self.stride,padding=1,groups=hidden_dim,bias=False),  # the depth-wise conv carries the block stride
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                nn.Conv2d(hidden_dim,out,kernel_size=1,stride=1,bias=False),
                nn.BatchNorm2d(out)  # linear bottleneck: no activation after the projection conv
            )
    def forward(self,x):
        if(self.use_res_connection):
            out=x+self.conv(x)
        else:
            out=self.conv(x)
        return out
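
# MobileViT network: a stride-2 stem conv, then alternating MV2 blocks (downsampling and
# local features) and MobileViT blocks (global features), followed by a 1 x 1 conv,
# global average pooling and a linear classifier.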

class MobileViT(nn.Module):
    def __init__(self,image_size,dims,channels,num_classes,depths=[2,4,3],expansion=4,kernel_size=3,patch_size=2):
        super().__init__()
        ih,iw=image_size,image_size
        ph,pw=patch_size,patch_size
        assert iw%pw==0 and ih%ph==0

        self.conv1=conv_bn(3,channels[0],kernel_size=3,stride=patch_size)
        self.mv2=nn.ModuleList([])
        self.m_vits=nn.ModuleList([])


        self.mv2.append(MV2Block(channels[0],channels[1],1))
        self.mv2.append(MV2Block(channels[1],channels[2],2))
        self.mv2.append(MV2Block(channels[2],channels[3],1))
        self.mv2.append(MV2Block(channels[2],channels[3],1)) # x2
        self.mv2.append(MV2Block(channels[3],channels[4],2))
        self.m_vits.append(MobileViTAttention(channels[4],dim=dims[0],kernel_size=kernel_size,patch_size=patch_size,depth=depths[0],mlp_dim=int(2*dims[0])))
        self.mv2.append(MV2Block(channels[4],channels[5],2))
        self.m_vits.append(MobileViTAttention(channels[5],dim=dims[1],kernel_size=kernel_size,patch_size=patch_size,depth=depths[1],mlp_dim=int(4*dims[1])))
        self.mv2.append(MV2Block(channels[5],channels[6],2))
        self.m_vits.append(MobileViTAttention(channels[6],dim=dims[2],kernel_size=kernel_size,patch_size=patch_size,depth=depths[2],mlp_dim=int(4*dims[2])))

        
        self.conv2=conv_bn(channels[-2],channels[-1],kernel_size=1)
        self.pool=nn.AvgPool2d(image_size//32,1)
        self.fc=nn.Linear(channels[-1],num_classes,bias=False)

    def forward(self,x):
        y=self.conv1(x) # 256 -> 128 (stride-2 stem, for a 256 x 256 input)
        y=self.mv2[0](y)
        y=self.mv2[1](y) # 128 -> 64
        y=self.mv2[2](y)
        y=self.mv2[3](y)
        y=self.mv2[4](y) # 64 -> 32
        y=self.m_vits[0](y)

        y=self.mv2[5](y) # 32 -> 16
        y=self.m_vits[1](y)

        y=self.mv2[6](y) # 16 -> 8
        y=self.m_vits[2](y)

        y=self.conv2(y)
        y=self.pool(y).view(y.shape[0],-1) 
        y=self.fc(y)
        return y

def mobilevit_xxs():
    dims=[60,80,96]
    channels= [16, 16, 24, 24, 48, 64, 80, 320]
    return MobileViT(256,dims,channels,num_classes=1000)  # 256 x 256 input: the final 8 x 8 feature map is still divisible by patch_size

def mobilevit_xs():
    dims = [96, 120, 144]
    channels = [16, 32, 48, 48, 64, 80, 96, 384]
    return MobileViT(256, dims, channels, num_classes=1000)

def mobilevit_s():
    dims = [144, 192, 240]
    channels = [16, 32, 64, 64, 96, 128, 160, 640]
    return MobileViT(256, dims, channels, num_classes=1000)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

if __name__ == '__main__':
    input=torch.randn(1,3,256,256)  # with 224 the final 7 x 7 feature map could not be split into 2 x 2 patches

    ### mobilevit_xxs
    mvit_xxs=mobilevit_xxs()
    out=mvit_xxs(input)
    print(out.shape)

    ### mobilevit_xs
    mvit_xs=mobilevit_xs()
    out=mvit_xs(input)
    print(out.shape)


    ### mobilevit_s
    mvit_s=mobilevit_s()
    out=mvit_s(input)
    print(out.shape)
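
    # Compare the trainable-parameter counts of the three variants with the
    # count_parameters helper defined above (this simplified re-implementation
    # will not match the paper's exact parameter budgets).
    for name, m in [('xxs', mvit_xxs), ('xs', mvit_xs), ('s', mvit_s)]:
        print(name, count_parameters(m))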
