蓝图分离卷积BSConv 学习笔记 (附代码)

论文地址:https://arxiv.org/abs/2003.13549

代码地址:https://github.com/zeiss-microscopy/BSConv

1.是什么?

BSConv是深度可分离卷积DSConv的升级版本,它更好地利用内核内部相关性来实现高效分离。具体而言,BSConvU是将一个标准的卷积分解为1x1卷积(PW)和一个逐通道卷积,是深度可分离卷积(DSConv—逐通道、逐点)的逆向版本。此外,BSConv还有一个变体操作—BSConvS。

2.为什么?

受启发于预训练模型的核属性的量化分析:深度方向的强相关性。作者提出一种“蓝图分离卷积”(blueprint separable convolutions, BSConv)作为高效CNN的构建模块。

基于该发现,作者构建了一套理论基础并用于推导如何采用标准OP进行高效实现。更进一步,所提方法为深度分离卷积的应用(深度分离卷积已成为当前主流网络架构的核心模块)提供了系统的理论推导、可解释性以及原因分析。最后,作者揭示了基于深度分离卷积的网络架构(如MobileNet)隐式的依赖于跨核相关性;而所提BSConv则基于核内相关性,故可以为常规卷积提供一种更有效的拆分。

作者通过充分的实验(大尺度分类与细粒度分类)验证了所提BSConv可以明显的提升MobileNet以及其他基于深度分离卷积的架构的性能,而不会引入额外的复杂度。对于细粒度问题,所提方法取得13.7%的性能提升;在ImageNet分类任务,BSConv在ResNet的“即插即用”取得了9.5%的性能提升。

3.怎么样?

3.1网络结构

蓝图分离卷积BSConv 学习笔记 (附代码)_第1张图片

在标准卷积中,每个卷积层对输入张量U\epsilon R^{M*Y*X}进行变化得到输出张量V\epsilon R^{N*Y*X},相应的卷积核F^{(1)},...,F^{(N)},每个卷积核的尺寸为M*K*K。相应的公式可以描述为(图示见下图):

这些卷积核将通过反向传播方式进行优化训练。

预训练CNN中的卷积核可以通过一个模板以及M个因子进行近似。该发现也是本文提的(blueprint separable convolutions,BSConv)的驱动源泉,它滤波器卷积提供另一种定义方式。

尽管上述定义为滤波器添加了硬约束,但作者通过实验表明:相比标准卷积,所提方法可以达到相同甚至更优的性能。另外,需要注意的是:标准卷积的可训练参数为M\cdot N\cdot K^{2},而所提方法仅具有N\cdot K^{2}+ M\cdot N个可训练参数。

3.2 Variants and Implementations

前面已经介绍了BSConv的卷积核信息,它的权值M\cdot N可以组合为矩阵W=(w_{n,m})。此时根据W的学习方式不同,又有两种不同的变种。

  • BSConv-U:在大多场景下,权值W可以不进行任何约束进行训练学习。此时,公式(1)可以转换为如下公式。此时,常规卷积1*1可以解耦为卷积K*K深度卷积,见下图。

蓝图分离卷积BSConv 学习笔记 (附代码)_第2张图片

 对于这种形式的CNN架构,作者发现:权值W在行方向存在高度相关性。这为进一步的正则化与参数降低提供了可能。也就引出了下面将要介绍的BSConv-S变种。

  • BSConv-S:基于前述发现,作者对权值W进行低秩分解:W = W^{A}*W^{B}。其中W^{A}=N*M',W^{B}=M'*M,M'=[p\cdot M],p\epsilon (0.0,1.0).而后,经过一些列的变换处理,最终BSConv的公式转换为下面的公式。此时,常规卷积可以解耦为1*1卷积+1*1卷积+K*K深度卷积,见上图。

3.3  Discussion

前面已经介绍了BSConv的两种变种,这里将对比分析一下上述两种变种与已有模块的区别和联系。

  • BSConv-U是一种逆深度分类卷积。两者的出发点有一些区别:DSConv实施了跨核相关性,而BSConv-U则实施了核内相关性。已有研究表明:尽管跨核相关性与核内相关性都是有效假设,但核内相关性更有优势,对于高效分离更具潜力。需要注意的是:卷积后不跟激活函数或者规范化函数。

  • BSConv-S是一种具有正交正则化功能的转移线性瓶颈模块。线性瓶颈层是当前高效网络MobileNet的核心模块,它由pointwise、depthwise、pointwise级联构成,而BSConv-S则是由pointwise, pointwise, depthwise级联构成。从中可以看到两者之间的紧密联系。此外,需要注意的是:与前者相同,激活函数与规范化函数不在模块内添加

3.4代码实现

class BSConvU(torch.nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode="zeros", with_bn=False, bn_kwargs=None):
        super().__init__()

        # check arguments
        if bn_kwargs is None:
            bn_kwargs = {}

        # pointwise
        self.add_module("pw", torch.nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=False,
        ))

        # batchnorm
        if with_bn:
            self.add_module("bn", torch.nn.BatchNorm2d(num_features=out_channels, **bn_kwargs))

        # depthwise
        self.add_module("dw", torch.nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=out_channels,
                bias=bias,
                padding_mode=padding_mode,
        ))


class BSConvS(torch.nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode="zeros", p=0.25, min_mid_channels=4, with_bn=False, bn_kwargs=None):
        super().__init__()

        # check arguments
        assert 0.0 <= p <= 1.0
        mid_channels = min(in_channels, max(min_mid_channels, math.ceil(p * in_channels)))
        if bn_kwargs is None:
            bn_kwargs = {}

        # pointwise 1
        self.add_module("pw1", torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=False,
        ))

        # batchnorm
        if with_bn:
            self.add_module("bn1", torch.nn.BatchNorm2d(num_features=mid_channels, **bn_kwargs))

        # pointwise 2
        self.add_module("pw2", torch.nn.Conv2d(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=False,
        ))

        # batchnorm
        if with_bn:
            self.add_module("bn2", torch.nn.BatchNorm2d(num_features=out_channels, **bn_kwargs))

        # depthwise
        self.add_module("dw", torch.nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=out_channels,
            bias=bias,
            padding_mode=padding_mode,
        ))

    def _reg_loss(self):
        W = self[0].weight[:, :, 0, 0]
        WWt = torch.mm(W, torch.transpose(W, 0, 1))
        I = torch.eye(WWt.shape[0], device=WWt.device)
        return torch.norm(WWt - I, p="fro")


class BSConvS_ModelRegLossMixin():
    def reg_loss(self, alpha=0.1):
        loss = 0.0
        for sub_module in self.modules():
            if hasattr(sub_module, "_reg_loss"):
                loss += sub_module._reg_loss()
        return alpha * loss

参考:

深度分离卷积重思考:BSConv

轻量化神经网络卷积设计研究进展

你可能感兴趣的:(深度学习,人工智能,笔记,网络,学习)