Separable Self-attention for Mobile Vision Transformers

paper链接: https://arxiv.org/pdf/2206.02680.pdf

code链接: https://github.com/apple/ml-cvnets

Separable Self-attention for Mobile Vision Transformers

  • (一)、引言
  • (二)、实现细节
  • (三)、实验
    • (一)、图像分类
    • (二)、语义分割
    • (三)、目标检测

(一)、引言

移动视觉transformers(MobileViT)可以在多个移动视觉任务中实现最先进的性能,包括分类和检测。虽然这些模型的参数更少,但与基于卷积神经网络的模型相比,它们具有较高的延迟。MobileViT的主要效率瓶颈是transformers中的多头自注意(MHA),它需要 O ( k 2 ) O(k^2) O(k2)时间复杂度,相对于令牌(或补丁)k的数量。此外,MHA需要复杂的操作(例如批次矩阵乘法)来计算自注意,影响资源受限设备的延迟。

本文提出了一种具有线性复杂度的可分离自注意方法,即 O ( k ) O(k) O(k)。该方法的一个简单而有效的特点是,它使用元素操作来计算自注意,这使它成为资源受限设备的一个很好的选择。改进后的模型MobileViTv2在一些移动视觉任务上是最先进的,包括ImageNet对象分类和MS-COCO对象检测。

MHA允许令牌(或补丁)彼此交互,并且是学习全局表示的关键。然而,Transformers中的自注意的复杂性是 O ( k 2 ) O(k^2) O(k2),即,它是关于令牌(或补丁)k的二次。除此之外,计算成本高的操作(例如,批次矩阵乘法;,如下图所示)来计算MHA中的注意矩阵。这尤其涉及到在资源受限的设备上部署基于vit的模型,因为这些设备具有较低的计算能力、限制性的内存约束和有限的功率预算。
Separable Self-attention for Mobile Vision Transformers_第1张图片
本文提出了一种求解ransformers MHA瓶颈问题的O(k)复杂度可分离自注意方法。为了有效地推断所提出的自注意方法还用基于元素的操作(例如,求和和乘法)取代了MHA中计算成本高昂的操作(例如,批矩阵乘法)。在标准视觉数据集和任务上的实验结果证明了所提方法的有效性(下图)。
Separable Self-attention for Mobile Vision Transformers_第2张图片

(二)、实现细节

MobileViT是一个混合网络,结合了CNN和VIT的优势。MobileViT将ransformers视为卷积,这允许它利用卷积(例如,归纳偏差)和ransformers(例如,远程依赖)的优点来为移动设备构建轻量级网络。尽管与轻量级CNN相比,MobileViT网络具有更少的参数并提供更好的性能(例如,MobileNets),但它们具有较高的延迟。MobileViT的主要效率瓶颈是多头自注意(MHA;下图 a)。

MHA使用缩放点积注意来捕获k个令牌(或补丁)之间的上下文关系。然而,由于MHA的时间复杂度为 O ( k 2 ) O(k^2) O(k2),因此成本较高。这种二次成本是具有大量令牌k的Transformers的瓶颈。此外,MHA使用计算和内存密集型操作(例如,批次矩阵乘法和softmax用于计算注意矩阵);这可能成为资源受限设备的瓶颈。为了解决MHA在资源受限设备上高效推理的局限性,本文引入了具有线性复杂度的可分离自注意(下图c)。
如下图所示,可分离自注意方法的主要思想是计算关于潜在令牌L的上下文分数。这些分数然后用于重新加权输入令牌,并产生一个上下文向量,对全局信息进行编码。
Separable Self-attention for Mobile Vision Transformers_第3张图片
由于自注意是根据潜在令牌计算的,所提出的方法可以将Transformers中自注意的复杂性降低k倍。所提出的方法的一个简单而有效的特点是,它使用元素的操作(例如,求和和乘法)来实现,使其成为资源受限设备的良好选择。
Separable Self-attention for Mobile Vision Transformers_第4张图片
可分离自注意的结构是由MHA启发的。与MHA类似,输入x使用三个分支进行处理,即输入I、键K和值v。输入分支I使用权重 W I ∈ R d W_I∈R^d WIRd的线性层将x中的每个d维令牌映射到一个标量。权重 W I W_I WI作为图4b中的潜在节点L。这个线性投影是一个内积运算,计算潜在标记L和x之间的距离,从而得到一个k维向量。然后对这个k维向量进行softmax运算,得到上下文分数 c s ∈ R k c_s∈R^k csRk。与针对所有k个令牌计算每个令牌的注意力(或上下文)得分的Transformers不同,提议的方法只计算关于潜在标记l的上下文分数,这将计算注意力(或上下文)分数的成本从 O ( k 2 ) O(k^2) O(k2)降低到O(k)。上下文分数 c s c_s cs用于计算上下文向量 c v c_v cv。具体来说,输入x使用权值 W K ∈ R d × d W_K∈R^{d×d} WKRd×d的关键分支K线性投影到d维空间,从而产生输出 x K ∈ R k × d x_K∈R^{k×d} xKRk×d。然后,上下文向量 c v ∈ R d c_v∈R^d cvRd被计算为 x K x_K xK的加权和:
在这里插入图片描述
上下文向量 c v c_v cv类似于MHA中的注意力矩阵a,在某种意义上,它也编码来自输入x中的所有标记的信息,但计算成本很低。

c v c_v cv中编码的上下文信息与x中的所有令牌共享。为此,输入x使用权重为 W V ∈ R d × d W_V∈R^{d×d} WVRd×d的值分支V线性投影到d维空间,然后通过ReLU激活产生输出 x V ∈ R k × d x_V∈R^{k×d} xVRk×d。cv中的上下文信息然后通过元素乘操作传播到 x V x_V xV。然后将结果输出馈送到权重为 W O ∈ R d × d W_O∈R^{d×d} WORd×d的另一个线性层,以产生最终输出 y ∈ R k × d y∈R^{k×d} yRk×d

数学上,可分离自我注意可以定义为:
Separable Self-attention for Mobile Vision Transformers_第5张图片
mobilevit2的基本架构如下:
Separable Self-attention for Mobile Vision Transformers_第6张图片
Separable Self-attention for Mobile Vision Transformers_第7张图片

Separable Self-attention for Mobile Vision Transformers_第8张图片

(三)、实验

(一)、图像分类

Separable Self-attention for Mobile Vision Transformers_第9张图片
Separable Self-attention for Mobile Vision Transformers_第10张图片

(二)、语义分割

Separable Self-attention for Mobile Vision Transformers_第11张图片

(三)、目标检测

Separable Self-attention for Mobile Vision Transformers_第12张图片
实现代码如下:
更多细节见原文以及源码部分



        self.qkv_proj = ConvLayer(
            opts=opts,
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=bias,
            kernel_size=1,
            use_norm=False,
            use_act=False,
        )

        qkv = self.qkv_proj(x)

        # Project x into query, key and value
        # Query --> [B, 1, P, N]
        # value, key --> [B, d, P, N]
        query, key, value = torch.split(
            qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1
        )

        # apply softmax along N dimension
        context_scores = F.softmax(query, dim=-1)
        # Uncomment below line to visualize context scores
        # self.visualize_context_scores(context_scores=context_scores)
        context_scores = self.attn_dropout(context_scores)

        # Compute context vector
        # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N]
        context_vector = key * context_scores
        # [B, d, P, N] --> [B, d, P, 1]
        context_vector = torch.sum(context_vector, dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        return out

你可能感兴趣的:(Transformer,深度学习,人工智能,神经网络)