设计移动设备上的 CNN 模型具有挑战性,需要保证模型小速度快且准确率高,人为地权衡这三方面很困难,有太多种可能结构需要考虑。Google 大脑 AutoML 组提出了一种用于设计资源受限的移动 CNN 模型的神经网络结构搜索方法,将时间延迟信息明确地整合到主要目标中,这样搜索模型可以识别一个网络是否很好地平衡了准确率和时间延迟。
在《MnasNet: Platform-Aware Neural Architecture Search for Mobile》一文中,作者探索了一种使用强化学习设计移动端模型的自动化神经架构搜索方法。为了处理移动端速度限制,将速度信息纳入搜索算法的主要奖励函数中,以便搜索可以识别一个在准确率和速度之间实现良好平衡的模型。如此,MnasNet 能够找到运行速度比 MobileNet V2 快 1.5 倍、比 NASNet 快 2.4 倍的型号,同时达到同样的 ImageNet top-1 准确率。
相比于之前的搜索策略,在这里使用的搜索方法有两点不同:
总体流程主要包括三个部分:一个基于 RNN 的学习和采样模型架构控制器,一个建立和训练模型以获得准确率的训练器,以及一个使用 TensorFlow Lite 测量真实手机上模型速度的推理引擎,作者制定了一个多目标优化问题,旨在实现高准确率和高速,并利用带有定制奖励函数的强化学习算法来寻找帕累托最优解。
一般我们会选定一个目标延迟(最大延迟),在延迟不超过这个最大延迟的情况下尽可能提高所选模型的准确率,即:
但是这种方法只是最大化了单一的变量而没有提供多变量的帕累托最优解,因此,我们将优化目标定义为:
α 和 β 的确定方法:力求在不同的准确率-延迟情况下达到(近乎)相同的 reward。
举例来说,假设 M1 模型的延迟为 l,准确率为 a;M2 模型的延迟为 2l,准确率为 a(1+5%),那么我们应该满足:
R e w a r d ( M 2 ) = a ( 1 + 5 % ) ( 2 l T ) β ≈ R e w a r d ( M 1 ) = a ( l T ) β Reward(M2)=a(1+5\%)(\frac{2l}{T})^{\beta}\approx Reward(M1)=a(\frac{l}{T})^{\beta} Reward(M2)=a(1+5%)(T2l)β≈Reward(M1)=a(Tl)β
解得 β=-0.07。在 MnasNet 论文中,作者使用的是 α=-0.07、β=-0.07 的情况。
α=0、β=-1 时:当 LAT(m)>T 时,reward 不可能有 ACC(m),即不可能大于 LAT(m)≤T 的情况,此时我们称延迟约束为硬约束。
α=-0.07、β=-0.07时:称延迟约束为软约束。
在搜索最优神经架构过程中,我们对以下几个方面进行了寻优操作(基于强化学习):
开发者在 ImageNet 分类和 COCO 物体检测上测试了这种方法的有效性。下图所示为该网络在 ImageNet 上的结果。
在相同的准确度下,MnasNet 模型的运行速度比手工设计的最先进的 MobileNetV2 模型快 1.5 倍,并且比 NASNet 快 2.4 倍,而 NASNet 也是使用架构搜索的方法。在应用压缩和激活优化方法后,MnasNet+SE 模型实现了 76.1% 的 ResNet-50 level top-1 准确率,并且参数数量是 MnasNet 的 1/19,乘加运算数量是 MnasNet 的 1/10。
MnasNet+SE 模型的结构如下图所示:
【注】这里用的激活函数都是 ReLU。
import tensorflow as tf
def conv_bn(x, filters, kernel_size, strides=1, activation=True):
x = tf.keras.layers.Conv2D(filters=filters,
kernel_size=kernel_size,
strides=strides,
padding='SAME')(x)
x = tf.keras.layers.BatchNormalization()(x)
if activation:
x = tf.keras.layers.Activation('relu')(x)
return x
def depthwiseConv_bn(x, kernel_size, strides):
x = tf.keras.layers.DepthwiseConv2D(kernel_size,
padding='same',
strides=strides)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation('relu')(x)
return x
def sepConv_bn_noskip(x, filters, kernel_size, strides=1):
x = depthwiseConv_bn(x, kernel_size=kernel_size, strides=strides)
x = conv_bn(x, filters=filters, kernel_size=1, strides=1)
return x
def Squeeze_excitation_layer(x):
inputs = x
squeeze = inputs.shape[-1]/2
excitation = inputs.shape[-1]
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(squeeze)(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Dense(excitation)(x)
x = tf.keras.layers.Activation('sigmoid')(x)
x = tf.keras.layers.Reshape((1, 1, excitation))(x)
x = inputs * x
return x
def MBConv_idskip(x, filters, kernel_size, strides, t, SE=False):
x_input = x
x = conv_bn(x, filters=x.shape[-1] * t, kernel_size=1, strides=1)
x = depthwiseConv_bn(x, kernel_size=kernel_size, strides=strides)
if SE:
x = Squeeze_excitation_layer(x)
x = conv_bn(x, filters=filters, kernel_size=1, strides=1, activation=False)
if strides==1 and x.shape[3] == x_input.shape[3]:
return tf.keras.layers.add([x_input, x])
else:
return x
def MBConv(x, filters, kernel_size, strides, t, n, SE=False):
x = MBConv_idskip(x, filters, kernel_size, strides, t, SE)
for _ in range(1, n):
x = MBConv_idskip(x, filters, kernel_size, strides=1, t=t, SE=SE)
return x
def MnasNet(x, n_classes=1000):
x = conv_bn(x, 32, kernel_size=3, strides=2)
x = sepConv_bn_noskip(x, filters=16, kernel_size=3, strides=1)
x = MBConv(x, filters=24, kernel_size=3, strides=2, t=6, n=2)
x = MBConv(x, filters=40, kernel_size=5, strides=2, t=3, n=3, SE=True)
x = MBConv(x, filters=80, kernel_size=3, strides=2, t=6, n=4)
x = MBConv(x, filters=96, kernel_size=3, strides=1, t=6, n=2, SE=True)
x = MBConv(x, filters=192, kernel_size=5, strides=2, t=6, n=3, SE=True)
x = MBConv(x, filters=320, kernel_size=3, strides=1, t=6, n=1)
x = conv_bn(x, filters=1152, kernel_size=1, strides=1)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
predictions = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
return predictions
inputs = np.zeros((1, 224, 224, 3), np.float32)
MnasNet(inputs).shape
TensorShape([1, 1000])
Mnasnet论文解析及开源实现