RepViT:从ViT的角度重新审视mobile CNN

文章目录

  • RepViT: Revisiting Mobile CNN From ViT Perspective
    • 摘要
    • 本文方法
    • 代码
    • 实验结果

RepViT: Revisiting Mobile CNN From ViT Perspective

摘要

近年来,与轻量级卷积神经网络(cnn)相比,轻量级视觉变压器(ViTs)在资源受限的移动设备上表现出了更高的性能和更低的延迟。这种改进通常归功于多头自注意模块,它使模型能够学习全局表示。然而,轻量级vit和轻量级cnn之间的架构差异还没有得到充分的研究。在这项研究中,我们重新审视了轻量级cnn的高效设计,并强调了它们在移动设备上的潜力。通过集成轻量级vit的高效架构选择,我们逐步增强了标准轻量级CNN的移动友好性,特别是MobileNetV3。这就产生了一个新的纯轻量级cnn家族,即RepViT。大量的实验表明,RepViT优于现有的轻型vit,并在各种视觉任务中表现出良好的延迟。在ImageNet上,RepViT在iPhone 12上以近1ms的延迟实现了超过80%的top-1精度,据我们所知,这是轻量级模型的第一次。
代码地址
RepViT:从ViT的角度重新审视mobile CNN_第1张图片

本文方法

RepViT:从ViT的角度重新审视mobile CNN_第2张图片
RepViT:从ViT的角度重新审视mobile CNN_第3张图片
图4。(a)表示MobileNetV3的块,具有可选的SE。在(b)中,我们通过重新定位SE,采用结构重参数化来分离token mixer和channel mixer。©涉及在推理阶段将多分支拓扑整合为单个分支。
原始的MobileNetV3块由1x1的扩展卷积,然后是深度卷积和1x1的投影层组成。剩余连接连接输入和输出。此外,挤压和激励模块可以任选地放置在扩展中的深度滤波器之后。从直观上看,1x1展开卷积和1x1投影层实现了通道间的交互,而深度卷积则实现了空间信息的融合.
前者和后者分别对应于通道混频器和令牌混频器。令牌混频器和通道混频器现在在MobileNetV3块中耦合在一起。因此,如图4 (b)所示,我们将深度卷积向上移动以拆分它们。同时,我们采用结构重参数化,在训练时为深度滤波器引入多分支拓扑,以提高性能。挤压和激励模块也被上移到深度滤波器之后,因为它依赖于空间信息交互。
因此,我们成功地分离了MobileNetV3块中的令牌混频器和通道混频器。此外,在推理过程中,如图4 ©所示,令牌混合器的多分支拓扑被合并为单个深度卷积。
RepViT:从ViT的角度重新审视mobile CNN_第4张图片
图5。(a)为MobileNetV3-L中的原始主干,为简单起见,非线性部分略去。我们使用早期卷积作为(b)中的干
RepViT:从ViT的角度重新审视mobile CNN_第5张图片
图6。(a)为MobileNetV3-L区块的原始下采样层。采用RepViT块设计后变为(b)。在©中,分别通过分别使用深度卷积和1x1卷积来调制特征图分辨率和通道维度。通过在前面合并一个RepViT块和在后面合并一个FFN,从而加深了所得的下采样层,增强了其整体架构。为简单起见,省略了非线性

RepViT:从ViT的角度重新审视mobile CNN_第6张图片

代码

class RepViTBlock(nn.Module):
    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        super(RepViTBlock, self).__init__()
        assert stride in [1, 2]

        self.identity = stride == 1 and inp == oup
        assert(hidden_dim == 2 * inp)

        if stride == 2:
            self.token_mixer = nn.Sequential(
                Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
                Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
            )
            self.channel_mixer = Residual(nn.Sequential(
                    # pw
                    Conv2d_BN(oup, 2 * oup, 1, 1, 0),
                    nn.GELU() if use_hs else nn.GELU(),
                    # pw-linear
                    Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
                ))
        else:
            assert(self.identity)
            self.token_mixer = nn.Sequential(
                RepVGGDW(inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
            )
            self.channel_mixer = Residual(nn.Sequential(
                    # pw
                    Conv2d_BN(inp, hidden_dim, 1, 1, 0),
                    nn.GELU() if use_hs else nn.GELU(),
                    # pw-linear
                    Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
                ))

    def forward(self, x):
        return self.channel_mixer(self.token_mixer(x))

class RepVGGDW(torch.nn.Module):
    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = Conv2d_BN(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed
    
    def forward(self, x):
        return self.conv(x) + self.conv1(x) + x
    
    @torch.no_grad()
    def fuse(self):
        conv = self.conv.fuse()
        conv1 = self.conv1.fuse()
        
        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias
        
        conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])

        identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])

        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)
        return conv

实验结果

RepViT:从ViT的角度重新审视mobile CNN_第7张图片
RepViT:从ViT的角度重新审视mobile CNN_第8张图片

你可能感兴趣的:(分割,cnn,人工智能,神经网络)