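Hello everyone. Today I'd like to share how to reproduce the lightweight neural network MobileNetV1 with TensorFlow. Mobile devices (phones) and edge devices (security cameras, autonomous vehicles) need to run neural networks in real time but have limited compute, so we want models with fewer parameters, less computation, fewer memory accesses, and lower energy consumption.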
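The figure below compares how forward-pass time is distributed across layers on GPU and CPU. Convolutional layers take up most of the time, and the larger the batch size, the more time they take. Many current networks no longer contain fully connected layers at all, so making a network lightweight is largely a matter of optimizing its convolutional layers.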
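The figure below is a scatter plot of computation versus accuracy for various networks. We want a network to need less computation while still keeping accuracy high, so the closer a network sits to the top-left corner of the plot, the better. The MobileNet family discussed today has roughly 3 to 4 million parameters and moderate accuracy.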
The core of MobileNetV1 is the depthwise separable convolution, which the rest of this section explains in detail.
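First, recall standard convolution: a multi-channel kernel slides over a multi-channel input image. At each position, the kernel weights are multiplied element-wise with the input pixels they cover and then summed, and the result is written to the corresponding pixel of the new feature map. The kernel has as many channels as the input. As the kernel slides over the image, the region it covers at each position is called the receptive field, as shown in the figure below.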
For example, if the input image has shape 5x5x3 and a single kernel has shape 3x3x3, the resulting feature map has shape 3x3x1 (with stride 1 and no padding).
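The shape arithmetic in this example follows the usual output-size formula for convolution. Here is a minimal sketch in plain Python. The helper name `conv_output_size` is made up for illustration and is not part of TensorFlow.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

# 5x5x3 input, a single 3x3x3 kernel, stride 1, no padding
height = conv_output_size(5, 3)
width = conv_output_size(5, 3)
num_kernels = 1  # each kernel yields exactly one output channel
print((height, width, num_kernels))  # (3, 3, 1)
```

The number of output channels never depends on the input channels, only on how many kernels are used.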
A depthwise separable convolution consists of a depthwise convolution followed by a pointwise convolution.
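Taking the figure below as an example, depthwise convolution works like this: for a three-channel input image, use three kernels, each with a single channel, and convolve each kernel with its own input channel. Every channel produces one feature map from its own kernel, so a three-channel input gives a three-channel output.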
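For example: an input image of shape 5x5x3 is split along the channel dimension into three 5x5x1 slices, and three 3x3x1 kernels are used, one per channel. Stacking the three results along the channel dimension gives an output of shape 3x3x3.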
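To make the per-channel behaviour concrete, here is a small NumPy sketch of a depthwise convolution with stride 1 and no padding. It is a naive loop written for readability, not how TensorFlow implements it, and the function name `depthwise_conv2d` is only illustrative.

```python
import numpy as np

def depthwise_conv2d(image, kernels):
    """Naive depthwise convolution (stride 1, no padding).

    image:   array of shape (H, W, C)
    kernels: array of shape (kh, kw, C), one single-channel kernel per input channel
    """
    h, w, c = image.shape
    kh, kw, kc = kernels.shape
    assert c == kc, "need exactly one kernel per input channel"
    out = np.zeros((h - kh + 1, w - kw + 1, c))
    for ch in range(c):  # each channel is processed independently
        for i in range(out.shape[0]):
            for j in range(out.shape[1]):
                patch = image[i:i + kh, j:j + kw, ch]
                out[i, j, ch] = np.sum(patch * kernels[:, :, ch])
    return out

image = np.arange(5 * 5 * 3, dtype=float).reshape(5, 5, 3)
kernels = np.ones((3, 3, 3))
features = depthwise_conv2d(image, kernels)
print(features.shape)  # (3, 3, 3)
```

Because the loop over `ch` never mixes channels, output channel 0 depends only on input channel 0, and so on.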
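Pointwise convolution handles cross-channel information (it fuses information across channels). It is an ordinary standard convolution, except that the kernel size is 1*1. As shown in the figure below, the input feature map has shape 5*5*3 and a single 1*1*3 kernel is used. As the kernel slides, its weights are multiplied with the corresponding input values and summed, giving an output feature map of shape 5*5*1.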
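Each depthwise kernel only looks at its own channel and ignores cross-channel information. The 1*1 convolution supplies what is missing. In the figure below, one 1*1 kernel produces one feature map, so n 1*1 kernels produce n feature maps.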
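A 1*1 convolution is just a weighted sum over channels at every pixel, so it can be sketched in NumPy with a single `einsum`. The function name `pointwise_conv2d` and the choice of 8 kernels are assumptions made for this illustration.

```python
import numpy as np

def pointwise_conv2d(feature_map, weights):
    """1x1 convolution: feature_map has shape (H, W, C_in), weights has shape (C_in, C_out)."""
    return np.einsum("hwc,cn->hwn", feature_map, weights)

feature_map = np.ones((5, 5, 3))
one_kernel = np.ones((3, 1))    # a single 1*1*3 kernel
many_kernels = np.ones((3, 8))  # n = 8 kernels of shape 1*1*3

print(pointwise_conv2d(feature_map, one_kernel).shape)    # (5, 5, 1)
print(pointwise_conv2d(feature_map, many_kernels).shape)  # (5, 5, 8)
```

The spatial size stays at 5*5 in both cases. Only the number of output channels changes with the number of kernels.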
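The depthwise convolution is applied first, then the pointwise convolution. The depthwise convolution handles spatial information along height and width and ignores cross-channel information. The pointwise convolution handles only cross-channel information and ignores spatial information, because its kernel is only 1*1.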
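The payoff of this split is a much smaller parameter count. The sketch below compares a standard 3*3 convolution with its depthwise separable counterpart for 32 input and 64 output channels, ignoring biases. The function names are illustrative. The two separable terms, 288 and 2048, are the same numbers that appear for `depthwise_conv2d` and `conv2d_1` in the model summary at the end of this post.

```python
def standard_conv_params(c_in, c_out, k=3):
    """Weights of a standard k x k convolution without bias."""
    return k * k * c_in * c_out

def separable_conv_params(c_in, c_out, k=3):
    """Weights of a depthwise k x k convolution plus a pointwise 1x1 convolution, without bias."""
    depthwise = k * k * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise

c_in, c_out = 32, 64
std = standard_conv_params(c_in, c_out)   # 18432
sep = separable_conv_params(c_in, c_out)  # 288 + 2048 = 2336
print(std, sep, round(sep / std, 4))      # 18432 2336 0.1267
```

The ratio equals 1/c_out + 1/k^2, so with 3*3 kernels the separable version needs roughly one eighth to one ninth of the weights.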
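The network is built from two modules: a standard convolution block and a depthwise separable convolution block. We define both first so they can be called directly later.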
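The model involves two hyperparameters. alpha is the width multiplier and controls the number of kernels. depth_multiplier is passed straight to Keras's DepthwiseConv2D, where it sets how many depthwise output channels are produced per input channel. It is kept at 1 here. Note that it does not change the spatial size of the feature maps. The resolution multiplier described in the paper is applied through the input image size (input_shape) instead.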
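The channel count of every layer is multiplied by alpha (the code below truncates the result with int()). Model size drops to roughly alpha^2 of the original, and so does the computation. This hyperparameter reduces the width of the model.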
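The alpha^2 scaling can be checked on a single depthwise separable block. This is a minimal sketch that counts weights only (no biases, no batch normalization). The helper name `separable_block_params` is an assumption for the example, and it truncates with `int()` the same way the blocks in this post do.

```python
def separable_block_params(c_in, c_out, alpha=1.0, k=3):
    """Weights of one depthwise separable block after scaling both widths by alpha."""
    c_in = int(c_in * alpha)
    c_out = int(c_out * alpha)
    return k * k * c_in + c_in * c_out

full = separable_block_params(512, 512, alpha=1.0)  # 4608 + 262144 = 266752
half = separable_block_params(512, 512, alpha=0.5)  # 2304 + 65536  = 67840
print(full, half, round(half / full, 4))            # 266752 67840 0.2543
```

The ratio is slightly above 0.25 because the depthwise term scales with alpha rather than alpha^2, but the pointwise term dominates.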
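The paper's resolution multiplier scales the resolution of the input layer, which is equivalent to scaling the resolution of every layer by the same factor. Model size stays the same, and computation drops to roughly the square of that factor. In this implementation the input resolution is lowered by passing a smaller input_shape. The depth_multiplier argument does not do this.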
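The effect of input resolution on computation can be sketched by counting multiply-adds for one depthwise separable block. The example assumes the 512-channel stage, whose feature map is 14*14 for a 224*224 input and 10*10 for a 160*160 input (four stride-2 steps with 'same' padding). The function name is illustrative.

```python
def separable_block_mult_adds(size, c_in, c_out, k=3):
    """Multiply-adds of one depthwise separable block on a size x size feature map."""
    depthwise = size * size * c_in * k * k
    pointwise = size * size * c_in * c_out
    return depthwise + pointwise

full = separable_block_mult_adds(14, 512, 512)     # 224x224 input
reduced = separable_block_mult_adds(10, 512, 512)  # 160x160 input
rho = 160 / 224
print(round(reduced / full, 4), round(rho ** 2, 4))  # 0.5102 0.5102
```

The weight count of the block does not depend on `size`, which is why lowering the resolution leaves the model size unchanged.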
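2.1 Standard convolution block
The standard convolution block consists of convolution + batch normalization + activation.
The activation used here is ReLU6. The main reason is to keep good numerical resolution when computing in low-precision float16 on mobile devices. If the output of ReLU is left unbounded, it ranges from 0 to positive infinity, and low-precision float16 cannot represent such values accurately, which costs precision.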
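#(1) Standard convolution block
def conv_block(input_tensor, filters, alpha, kernel_size=(3,3), strides=(1,1)):
    # The alpha hyperparameter scales the number of kernels
    filters = int(filters*alpha)
    # Convolution + batch normalization + activation
    x = layers.Conv2D(filters, kernel_size,
                      strides=strides,  # stride
                      padding='same',   # zero padding; feature map size is unchanged when strides=1
                      use_bias=False)(input_tensor)  # no bias needed when a BN layer follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU(6.0)(x)  # ReLU6 activation
    return x  # result of one standard convolution block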
The ReLU6 and ReLU functions are plotted below.
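The two activations differ only in the upper clamp at 6. A minimal sketch in plain Python, with function names chosen for this example:

```python
def relu(x):
    """Standard ReLU: negative inputs become 0, positive inputs pass through."""
    return max(x, 0.0)

def relu6(x):
    """ReLU6: like ReLU, but the output is capped at 6."""
    return min(max(x, 0.0), 6.0)

samples = (-3.0, 2.5, 6.0, 100.0)
print([relu(v) for v in samples])   # [0.0, 2.5, 6.0, 100.0]
print([relu6(v) for v in samples])  # [0.0, 2.5, 6.0, 6.0]
```

In the Keras code in this post the same behaviour comes from `layers.ReLU(6.0)`, where 6.0 is the layer's maximum output value.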
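If a convolutional layer is followed by a BatchNormalization layer, the bias can be dropped with use_bias=False. Batch normalization subtracts the per-channel mean, which cancels any constant bias, so adding one has no effect on the model and only uses extra memory.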
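This cancellation is easy to verify numerically. The NumPy sketch below uses a simplified batch normalization (training-mode statistics, no learned scale or shift) and random data standing in for a convolution output. Both are assumptions made for the demonstration.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Normalize each channel (last axis) over all other axes, without learned scale or shift."""
    axes = tuple(range(x.ndim - 1))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
conv_out = rng.normal(size=(4, 8, 8, 16))  # stand-in for a convolution output (N, H, W, C)
bias = rng.normal(size=(16,))              # one bias value per channel

same = np.allclose(batch_norm(conv_out), batch_norm(conv_out + bias))
print(same)  # True
```

Adding the bias shifts each channel's mean by the same constant and leaves the variance untouched, so the normalized output is identical.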
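#(2) Depthwise separable convolution block
def depthwise_conv_block(input_tensor, point_filters, alpha, depth_multiplier, strides=(1,1)):
    # The alpha hyperparameter scales the number of pointwise kernels
    point_filters = int(point_filters*alpha)
    # 1. Depthwise convolution: with depth_multiplier=1 the output has as many channels as the input
    x = layers.DepthwiseConv2D(kernel_size=(3,3),  # 3*3 kernel
                               strides=strides,    # stride
                               padding='same',     # feature map size is unchanged when strides=1
                               depth_multiplier=depth_multiplier,  # depthwise output channels per input channel
                               use_bias=False)(input_tensor)  # no bias needed when a BN layer follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU(6.0)(x)  # ReLU6 activation
    # 2. Pointwise convolution: a standard 1*1 convolution
    x = layers.Conv2D(point_filters, kernel_size=(1,1),  # 1*1 kernel
                      padding='same',   # feature map size is unchanged
                      strides=(1,1),    # stride 1, so every pixel of the feature map is convolved
                      use_bias=False)(x)  # no bias needed when a BN layer follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU(6.0)(x)  # ReLU6 activation
    return x  # result of the depthwise separable convolution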
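Following the network architecture in the paper, we stack the layers one by one. In the table below, Conv dw is a depthwise convolution and Conv / s1 is a pointwise convolution.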
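Before building the Keras model, the stack can be traced on paper. The sketch below lists the pointwise filter count and stride of each of the 13 depthwise separable blocks, as they are used in the full code that follows, and computes the feature map shape after every block. The names `BLOCKS` and `trace_shapes` are made up for this illustration.

```python
import math

# (pointwise filters, stride) for the 13 depthwise separable blocks, in order
BLOCKS = [(64, 1), (128, 2), (128, 1), (256, 2), (256, 1), (512, 2),
          (512, 1), (512, 1), (512, 1), (512, 1), (512, 1), (1024, 2), (1024, 1)]

def trace_shapes(input_size=224, alpha=1.0):
    """Return the (height, width, channels) shape after the stem and after each block."""
    size = math.ceil(input_size / 2)  # stem convolution: stride 2 with 'same' padding
    shapes = [(size, size, int(32 * alpha))]
    for filters, stride in BLOCKS:
        size = math.ceil(size / stride)  # 'same' padding: output size is ceil(input / stride)
        shapes.append((size, size, int(filters * alpha)))
    return shapes

shapes = trace_shapes()
print(shapes[0])   # (112, 112, 32)
print(shapes[-1])  # (7, 7, 1024)
```

Calling `trace_shapes(alpha=0.5)` halves every channel count while leaving the spatial sizes unchanged.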
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model, optimizers
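#(1) Standard convolution block
def conv_block(input_tensor, filters, alpha, kernel_size=(3,3), strides=(1,1)):
    # The alpha hyperparameter scales the number of kernels
    filters = int(filters*alpha)
    # Convolution + batch normalization + activation
    x = layers.Conv2D(filters, kernel_size,
                      strides=strides,  # stride
                      padding='same',   # zero padding; feature map size is unchanged when strides=1
                      use_bias=False)(input_tensor)  # no bias needed when a BN layer follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU(6.0)(x)  # ReLU6 activation
    return x  # result of one standard convolution block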
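#(2) Depthwise separable convolution block
def depthwise_conv_block(input_tensor, point_filters, alpha, depth_multiplier, strides=(1,1)):
    # The alpha hyperparameter scales the number of pointwise kernels
    point_filters = int(point_filters*alpha)
    # 1. Depthwise convolution: with depth_multiplier=1 the output has as many channels as the input
    x = layers.DepthwiseConv2D(kernel_size=(3,3),  # 3*3 kernel
                               strides=strides,    # stride
                               padding='same',     # feature map size is unchanged when strides=1
                               depth_multiplier=depth_multiplier,  # depthwise output channels per input channel
                               use_bias=False)(input_tensor)  # no bias needed when a BN layer follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU(6.0)(x)  # ReLU6 activation
    # 2. Pointwise convolution: a standard 1*1 convolution
    x = layers.Conv2D(point_filters, kernel_size=(1,1),  # 1*1 kernel
                      padding='same',   # feature map size is unchanged
                      strides=(1,1),    # stride 1, so every pixel of the feature map is convolved
                      use_bias=False)(x)  # no bias needed when a BN layer follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU(6.0)(x)  # ReLU6 activation
    return x  # result of the depthwise separable convolution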
#(3) Backbone network
def MobileNet(classes, input_shape, alpha, depth_multiplier, dropout_rate):
    # Create the input layer
    inputs = layers.Input(shape=input_shape)  # [224,224,3]
    # [224,224,3]==>[112,112,32]
    x = conv_block(inputs, 32, alpha, strides=(2,2))  # stride 2 halves width and height; channels go up
    # [112,112,32]==>[112,112,64]
    x = depthwise_conv_block(x, 64, alpha, depth_multiplier)  # depthwise separable conv with 64 pointwise kernels
    # [112,112,64]==>[56,56,128]
    x = depthwise_conv_block(x, 128, alpha, depth_multiplier, strides=(2,2))  # stride 2 shrinks the feature map
    # [56,56,128]==>[56,56,128]
    x = depthwise_conv_block(x, 128, alpha, depth_multiplier)
    # [56,56,128]==>[28,28,256]
    x = depthwise_conv_block(x, 256, alpha, depth_multiplier, strides=(2,2))
    # [28,28,256]==>[28,28,256]
    x = depthwise_conv_block(x, 256, alpha, depth_multiplier)
    # [28,28,256]==>[14,14,512]
    x = depthwise_conv_block(x, 512, alpha, depth_multiplier, strides=(2,2))
    # [14,14,512]==>[14,14,512], repeated five times
    x = depthwise_conv_block(x, 512, alpha, depth_multiplier)
    x = depthwise_conv_block(x, 512, alpha, depth_multiplier)
    x = depthwise_conv_block(x, 512, alpha, depth_multiplier)
    x = depthwise_conv_block(x, 512, alpha, depth_multiplier)
    x = depthwise_conv_block(x, 512, alpha, depth_multiplier)
    # [14,14,512]==>[7,7,1024]
    x = depthwise_conv_block(x, 1024, alpha, depth_multiplier, strides=(2,2))
    # [7,7,1024]==>[7,7,1024]
    x = depthwise_conv_block(x, 1024, alpha, depth_multiplier)
    # [7,7,1024]==>[1024] global average pooling
    x = layers.GlobalAveragePooling2D()(x)  # average each channel over its spatial dimensions
    # alpha scales the number of channels (feature maps)
    shape = (1, 1, int(1024 * alpha))
    # [1024]==>[1,1,1024] so that a 1*1 convolution can act as the classifier
    x = layers.Reshape(target_shape=shape)(x)
    # Dropout randomly zeroes activations to reduce overfitting
    x = layers.Dropout(rate=dropout_rate)(x)
    # 1*1 convolution maps the channels to the number of classes
    x = layers.Conv2D(classes, kernel_size=(1,1), padding='same')(x)
    # softmax turns the outputs into class probabilities
    x = layers.Activation('softmax')(x)
    # [1,1,classes]==>[classes]
    x = layers.Reshape(target_shape=(classes,))(x)
    # Build the model
    model = Model(inputs, x)
    # Return the model
    return model
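if __name__ == '__main__':
    # Build the model
    model = MobileNet(classes=1000,             # number of classes
                      input_shape=[224,224,3],  # shape of the input image
                      alpha=1.0,                # width multiplier, scales the number of kernels
                      depth_multiplier=1,       # depthwise output channels per input channel
                      dropout_rate=1e-3)        # probability of randomly dropping a unit
    # Print the model structure
    model.summary()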
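The printed model structure is shown below. The model has about 4.25 million parameters, which is very lightweight compared with the more than one hundred million parameters of a VGG network.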
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 224, 224, 3)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 112, 112, 32) 864
_________________________________________________________________
batch_normalization (BatchNo (None, 112, 112, 32) 128
_________________________________________________________________
re_lu (ReLU) (None, 112, 112, 32) 0
_________________________________________________________________
depthwise_conv2d (DepthwiseC (None, 112, 112, 32) 288
_________________________________________________________________
batch_normalization_1 (Batch (None, 112, 112, 32) 128
_________________________________________________________________
re_lu_1 (ReLU) (None, 112, 112, 32) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 112, 112, 64) 2048
_________________________________________________________________
batch_normalization_2 (Batch (None, 112, 112, 64) 256
_________________________________________________________________
re_lu_2 (ReLU) (None, 112, 112, 64) 0
_________________________________________________________________
depthwise_conv2d_1 (Depthwis (None, 56, 56, 64) 576
_________________________________________________________________
batch_normalization_3 (Batch (None, 56, 56, 64) 256
_________________________________________________________________
re_lu_3 (ReLU) (None, 56, 56, 64) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 56, 56, 128) 8192
_________________________________________________________________
batch_normalization_4 (Batch (None, 56, 56, 128) 512
_________________________________________________________________
re_lu_4 (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
depthwise_conv2d_2 (Depthwis (None, 56, 56, 128) 1152
_________________________________________________________________
batch_normalization_5 (Batch (None, 56, 56, 128) 512
_________________________________________________________________
re_lu_5 (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 56, 56, 128) 16384
_________________________________________________________________
batch_normalization_6 (Batch (None, 56, 56, 128) 512
_________________________________________________________________
re_lu_6 (ReLU) (None, 56, 56, 128) 0
_________________________________________________________________
depthwise_conv2d_3 (Depthwis (None, 28, 28, 128) 1152
_________________________________________________________________
batch_normalization_7 (Batch (None, 28, 28, 128) 512
_________________________________________________________________
re_lu_7 (ReLU) (None, 28, 28, 128) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 28, 28, 256) 32768
_________________________________________________________________
batch_normalization_8 (Batch (None, 28, 28, 256) 1024
_________________________________________________________________
re_lu_8 (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
depthwise_conv2d_4 (Depthwis (None, 28, 28, 256) 2304
_________________________________________________________________
batch_normalization_9 (Batch (None, 28, 28, 256) 1024
_________________________________________________________________
re_lu_9 (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 28, 28, 256) 65536
_________________________________________________________________
batch_normalization_10 (Batc (None, 28, 28, 256) 1024
_________________________________________________________________
re_lu_10 (ReLU) (None, 28, 28, 256) 0
_________________________________________________________________
depthwise_conv2d_5 (Depthwis (None, 14, 14, 256) 2304
_________________________________________________________________
batch_normalization_11 (Batc (None, 14, 14, 256) 1024
_________________________________________________________________
re_lu_11 (ReLU) (None, 14, 14, 256) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 14, 14, 512) 131072
_________________________________________________________________
batch_normalization_12 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_12 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_6 (Depthwis (None, 14, 14, 512) 4608
_________________________________________________________________
batch_normalization_13 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_13 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
batch_normalization_14 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_14 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_7 (Depthwis (None, 14, 14, 512) 4608
_________________________________________________________________
batch_normalization_15 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_15 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_8 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
batch_normalization_16 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_16 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_8 (Depthwis (None, 14, 14, 512) 4608
_________________________________________________________________
batch_normalization_17 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_17 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_9 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
batch_normalization_18 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_18 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_9 (Depthwis (None, 14, 14, 512) 4608
_________________________________________________________________
batch_normalization_19 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_19 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_10 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
batch_normalization_20 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_20 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_10 (Depthwi (None, 14, 14, 512) 4608
_________________________________________________________________
batch_normalization_21 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_21 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
conv2d_11 (Conv2D) (None, 14, 14, 512) 262144
_________________________________________________________________
batch_normalization_22 (Batc (None, 14, 14, 512) 2048
_________________________________________________________________
re_lu_22 (ReLU) (None, 14, 14, 512) 0
_________________________________________________________________
depthwise_conv2d_11 (Depthwi (None, 7, 7, 512) 4608
_________________________________________________________________
batch_normalization_23 (Batc (None, 7, 7, 512) 2048
_________________________________________________________________
re_lu_23 (ReLU) (None, 7, 7, 512) 0
_________________________________________________________________
conv2d_12 (Conv2D) (None, 7, 7, 1024) 524288
_________________________________________________________________
batch_normalization_24 (Batc (None, 7, 7, 1024) 4096
_________________________________________________________________
re_lu_24 (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
depthwise_conv2d_12 (Depthwi (None, 7, 7, 1024) 9216
_________________________________________________________________
batch_normalization_25 (Batc (None, 7, 7, 1024) 4096
_________________________________________________________________
re_lu_25 (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
conv2d_13 (Conv2D) (None, 7, 7, 1024) 1048576
_________________________________________________________________
batch_normalization_26 (Batc (None, 7, 7, 1024) 4096
_________________________________________________________________
re_lu_26 (ReLU) (None, 7, 7, 1024) 0
_________________________________________________________________
global_average_pooling2d (Gl (None, 1024) 0
_________________________________________________________________
reshape (Reshape) (None, 1, 1, 1024) 0
_________________________________________________________________
dropout (Dropout) (None, 1, 1, 1024) 0
_________________________________________________________________
conv2d_14 (Conv2D) (None, 1, 1, 1000) 1025000
_________________________________________________________________
activation (Activation) (None, 1, 1, 1000) 0
_________________________________________________________________
reshape_1 (Reshape) (None, 1000) 0
=================================================================
Total params: 4,253,864
Trainable params: 4,231,976
Non-trainable params: 21,888
_________________________________________________________________
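The totals in the summary above can be cross-checked by hand. The sketch below recomputes them in plain Python from the layer structure used in this post (alpha = 1, depth_multiplier = 1, 1000 classes). The function names are illustrative and nothing here calls TensorFlow.

```python
def batch_norm_params(channels):
    """Return (total, non_trainable) parameter counts of one BatchNormalization layer."""
    # gamma and beta are trainable; moving mean and moving variance are not
    return 4 * channels, 2 * channels

def count_mobilenet_params(classes=1000):
    # Pointwise output channels of the 13 depthwise separable blocks
    block_channels = [64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024]
    total = 0
    non_trainable = 0

    def add_batch_norm(channels):
        nonlocal total, non_trainable
        bn_total, bn_frozen = batch_norm_params(channels)
        total += bn_total
        non_trainable += bn_frozen

    # Stem: 3x3 convolution from 3 to 32 channels, no bias
    c_in = 32
    total += 3 * 3 * 3 * c_in
    add_batch_norm(c_in)

    for c_out in block_channels:
        total += 3 * 3 * c_in  # depthwise 3x3, one kernel per channel, no bias
        add_batch_norm(c_in)
        total += c_in * c_out  # pointwise 1x1, no bias
        add_batch_norm(c_out)
        c_in = c_out

    total += c_in * classes + classes  # final 1x1 convolution, with bias
    return total, non_trainable

total, non_trainable = count_mobilenet_params()
print(total, total - non_trainable, non_trainable)  # 4253864 4231976 21888
```

The three printed numbers match the Total, Trainable, and Non-trainable lines of the summary. About a quarter of all parameters sit in the final 1000-way classifier.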