MnasNet 网络原理与 Tensorflow2.0 实现(with SE module)

文章目录

  • 介绍
  • 网络介绍
    • 总体流程
    • 优化目标
    • 应用神经架构搜索
    • Mnasnet 结构
    • 含有 SE 模块的 Mnasnet 结构
  • 代码实现
  • 参考资料

介绍

设计移动设备上的 CNN 模型具有挑战性,需要保证模型小速度快且准确率高,人为地权衡这三方面很困难,有太多种可能结构需要考虑。Google 大脑 AutoML 组提出了一种用于设计资源受限的移动 CNN 模型的神经网络结构搜索方法,将时间延迟信息明确地整合到主要目标中,这样搜索模型可以识别一个网络是否很好地平衡了准确率和时间延迟。

在《MnasNet: Platform-Aware Neural Architecture Search for Mobile》一文中,作者探索了一种使用强化学习设计移动端模型的自动化神经架构搜索方法。为了处理移动端速度限制,将速度信息纳入搜索算法的主要奖励函数中,以便搜索可以识别一个在准确率和速度之间实现良好平衡的模型。如此,MnasNet 能够找到运行速度比 MobileNet V2 快 1.5 倍、比 NASNet 快 2.4 倍的型号,同时达到同样的 ImageNet top-1 准确率。

网络介绍

总体流程

相比于之前的搜索策略,在这里使用的搜索方法有两点不同:

  • 1、在这里我们使用的是同时考虑准确率和延迟的多目标优化;
  • 2、直接测量实际中在手机端的推理延迟而不是使用 FLOPS 间接表示。

总体流程主要包括三个部分:一个基于 RNN 的学习和采样模型架构控制器,一个建立和训练模型以获得准确率的训练器,以及一个使用 TensorFlow Lite 测量真实手机上模型速度的推理引擎,作者制定了一个多目标优化问题,旨在实现高准确率和高速,并利用带有定制奖励函数的强化学习算法来寻找帕累托最优解。
MnasNet 网络原理与 Tensorflow2.0 实现(with SE module)_第1张图片

  • 首先我们计算 latency (来自于训练时候的时间) 以及网络准确度,共同得到一个 reward,这个 reward 即代表了我们要权衡的两个东西,运算时间和网络准确度;
  • 然后将 reward 返回到一个 controller 中, 这个 controller 应该就是决定了如何对网络进行重构和择优。

优化目标

一般我们会选定一个目标延迟(最大延迟),在延迟不超过这个最大延迟的情况下尽可能提高所选模型的准确率,即:
在这里插入图片描述
但是这种方法只是最大化了单一的变量而没有提供多变量的帕累托最优解,因此,我们将优化目标定义为:
MnasNet 网络原理与 Tensorflow2.0 实现(with SE module)_第2张图片
α 和 β 的确定方法:力求在不同的准确率-延迟情况下达到(近乎)相同的 reward。

举例来说,假设 M1 模型的延迟为 l,准确率为 a;M2 模型的延迟为 2l,准确率为 a(1+5%),那么我们应该满足:
R e w a r d ( M 2 ) = a ( 1 + 5 % ) ( 2 l T ) β ≈ R e w a r d ( M 1 ) = a ( l T ) β Reward(M2)=a(1+5\%)(\frac{2l}{T})^{\beta}\approx Reward(M1)=a(\frac{l}{T})^{\beta} Reward(M2)=a(1+5%)(T2l)βReward(M1)=a(Tl)β
解得 β=-0.07。在 MnasNet 论文中,作者使用的是 α=-0.07、β=-0.07 的情况。

α=0、β=-1 时:当 LAT(m)>T 时,reward 不可能有 ACC(m),即不可能大于 LAT(m)≤T 的情况,此时我们称延迟约束为硬约束。
α=-0.07、β=-0.07时:称延迟约束为软约束。

应用神经架构搜索

MnasNet 网络原理与 Tensorflow2.0 实现(with SE module)_第3张图片
在搜索最优神经架构过程中,我们对以下几个方面进行了寻优操作(基于强化学习):

  • 卷积操作:标准卷积 or 深度级卷积 or MobileNet V2 中的 BottleNeck block;
  • 卷积核尺寸:3x3 or 5x5;
  • 跳跃操作:无跳跃 or 池化跳跃 or 残差跳跃;
  • SE Ratio:0 or 0.25;
  • 卷积核个数:基于 MobileNet V2,是 MobileNet V2 中的 {0.75 or 1.0 or 1.25} 倍;
  • 每个模块内的层数:基于 MobileNet V2,相比于 MobileNet V2 中的层数 {+1 or +0 or -1}。

Mnasnet 结构

MnasNet 网络原理与 Tensorflow2.0 实现(with SE module)_第4张图片
值得注意的地方有:

  • 1、上图中 MBConv 后面跟的数字表示扩展系数(3 或 6),即经过第一个 Conv 1x1 后的通道数与经过前的通道数之比。
  • 2、与 MobileNet V2 类似,当 MBConv 模块是第一次出现时,其中 DwiseConv 的步长可以是任意的(通常是 1 或 2),但后面重复该模块时步长必须设为 1。(步长是 2 时会对 feature map 的形状产生影响)
  • 3、和 MobileNet V2 不同的是,每个 MBConv 模块中的第二个 Conv 1x1 中的卷积核个数并不需要等于输入的通道数,从 (a) 图中也能看出这一点。
  • 4、和 MobileNet V2 一样,Mnasnet 用的激活函数也是 ReLU6(除了之后提及的 SE 模块),且 MBConv 模块中的第二个 Conv 1x1 之后不设置激活函数。

含有 SE 模块的 Mnasnet 结构

开发者在 ImageNet 分类和 COCO 物体检测上测试了这种方法的有效性。下图所示为该网络在 ImageNet 上的结果。
MnasNet 网络原理与 Tensorflow2.0 实现(with SE module)_第5张图片
在相同的准确度下,MnasNet 模型的运行速度比手工设计的最先进的 MobileNetV2 模型快 1.5 倍,并且比 NASNet 快 2.4 倍,而 NASNet 也是使用架构搜索的方法。在应用压缩和激活优化方法后,MnasNet+SE 模型实现了 76.1% 的 ResNet-50 level top-1 准确率,并且参数数量是 MnasNet 的 1/19,乘加运算数量是 MnasNet 的 1/10。

MnasNet+SE 模型的结构如下图所示:
MnasNet 网络原理与 Tensorflow2.0 实现(with SE module)_第6张图片
【注】这里用的激活函数都是 ReLU。

代码实现

import tensorflow as tf

def conv_bn(x, filters, kernel_size, strides=1, activation=True):
    x = tf.keras.layers.Conv2D(filters=filters, 
                               kernel_size=kernel_size, 
                               strides=strides, 
                               padding='SAME')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    if activation:
        x = tf.keras.layers.Activation('relu')(x)
        
    return x

def depthwiseConv_bn(x, kernel_size, strides):

    x = tf.keras.layers.DepthwiseConv2D(kernel_size, 
                                        padding='same', 
                                        strides=strides)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)

    return x

def sepConv_bn_noskip(x, filters, kernel_size, strides=1):
    
    x = depthwiseConv_bn(x, kernel_size=kernel_size, strides=strides)
    x = conv_bn(x, filters=filters, kernel_size=1, strides=1)
    
    return x

def Squeeze_excitation_layer(x):
    
    inputs = x
    squeeze = inputs.shape[-1]/2
    excitation = inputs.shape[-1]
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(squeeze)(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Dense(excitation)(x)
    x = tf.keras.layers.Activation('sigmoid')(x)
    x = tf.keras.layers.Reshape((1, 1, excitation))(x)
    x = inputs * x

    return x

def MBConv_idskip(x, filters, kernel_size, strides, t, SE=False):
    
    x_input = x
    
    x = conv_bn(x, filters=x.shape[-1] * t, kernel_size=1, strides=1)
    x = depthwiseConv_bn(x, kernel_size=kernel_size, strides=strides)
    if SE:
        x = Squeeze_excitation_layer(x)
    x = conv_bn(x, filters=filters, kernel_size=1, strides=1, activation=False)
    
    if strides==1 and x.shape[3] == x_input.shape[3]:
        return  tf.keras.layers.add([x_input, x])
    else: 
        return x

def MBConv(x, filters, kernel_size, strides, t, n, SE=False):
    
    x = MBConv_idskip(x, filters, kernel_size, strides, t, SE)
    
    for _ in range(1, n):
        x = MBConv_idskip(x, filters, kernel_size, strides=1, t=t, SE=SE)
        
    return x

def MnasNet(x, n_classes=1000):
    
    x = conv_bn(x, 32, kernel_size=3, strides=2)
    x = sepConv_bn_noskip(x, filters=16, kernel_size=3, strides=1)
    x = MBConv(x, filters=24, kernel_size=3, strides=2, t=6, n=2)
    x = MBConv(x, filters=40, kernel_size=5, strides=2, t=3, n=3, SE=True)
    x = MBConv(x, filters=80, kernel_size=3, strides=2, t=6, n=4)
    x = MBConv(x, filters=96, kernel_size=3, strides=1, t=6, n=2, SE=True)
    x = MBConv(x, filters=192, kernel_size=5, strides=2, t=6, n=3, SE=True)
    x = MBConv(x, filters=320, kernel_size=3, strides=1, t=6, n=1)
    
    x = conv_bn(x, filters=1152, kernel_size=1, strides=1)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    predictions = tf.keras.layers.Dense(n_classes, activation='softmax')(x)

    return predictions

inputs = np.zeros((1, 224, 224, 3), np.float32)
MnasNet(inputs).shape
TensorShape([1, 1000])

参考资料

Mnasnet论文解析及开源实现

你可能感兴趣的:(Tensorflow,2.0,深度学习,深度学习,tensorflow,人工智能,算法,神经网络)