Hello everyone. Today I'd like to share how to build the Xception neural network model with TensorFlow.
In earlier chapters I have already covered a number of lightweight convolutional neural network models; if you are interested, have a look: https://blog.csdn.net/dgvv4/category_11517910.html
Xception is an architecture that balances accuracy and efficiency. As shown in the figure below, the horizontal axis is computational cost and the vertical axis is accuracy. Xception sits in the top tier for accuracy, and its computational cost also places it among the lightweight models.
Xception uses the depthwise separable convolution method from MobileNetV1, so I recommend studying MobileNetV1 first: https://blog.csdn.net/dgvv4/article/details/123415708
To help you master Xception, let's first briefly review depthwise separable convolution.
In a standard convolution, one kernel processes all of the channels: the kernel has as many channels as the input feature map, and each kernel produces one output feature map.
A depthwise separable convolution can be understood as a depthwise convolution followed by a pointwise convolution.
The depthwise convolution handles only the spatial information (height and width), while the pointwise convolution handles only the cross-channel information. Together they greatly reduce the number of parameters and improve computational efficiency.
Depthwise convolution: each kernel processes only one channel, i.e. every kernel handles its own corresponding channel. There are as many kernels as input channels, and the per-channel outputs are stacked together, so the input and output feature maps have the same number of channels.
Because processing only the spatial dimensions loses cross-channel information, a pointwise convolution is applied afterwards to bring that information back.
Pointwise convolution: a 1x1 convolution applied across the channel dimension; the number of 1x1 kernels determines the number of output feature maps.
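To make the decomposition and the parameter savings concrete, here is a minimal sketch (my own illustration, not part of the author's code) that maps a 32-channel feature map to 64 channels twice with Keras layers, once with a standard convolution and once with depthwise + pointwise convolutions, and compares the parameter counts. The shapes (56x56x32 input, 64 output channels, 3x3 kernel) are arbitrary example values.

import tensorflow as tf
from tensorflow.keras import layers

inputs = tf.keras.Input(shape=(56, 56, 32))  # example feature map: 56x56, 32 channels

# Standard convolution: every kernel spans all 32 input channels
standard = layers.Conv2D(64, kernel_size=3, padding='same', use_bias=False)(inputs)

# Depthwise separable convolution = depthwise (spatial) + pointwise (cross-channel)
x = layers.DepthwiseConv2D(kernel_size=3, padding='same', use_bias=False)(inputs)  # one 3x3 kernel per channel
separable = layers.Conv2D(64, kernel_size=1, use_bias=False)(x)                    # 1x1 pointwise convolution

tf.keras.Model(inputs, [standard, separable]).summary()
# Standard convolution: 3*3*32*64 = 18,432 parameters
# Depthwise separable:  3*3*32 + 1*1*32*64 = 2,336 parameters (roughly 8x fewer)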
Next, let's walk through how the core module evolved from Inception to Xception, to deepen our understanding of the Xception structure.
First, InceptionV1 is built by stacking 9 BottleNeck-Inception modules, as shown in the figure below.
How the Inception module works: the input feature map is split into four branches that apply four different operations, and the four resulting feature maps are concatenated and passed on to the next layer.
By decomposing and decoupling the computation as much as possible, different scales and different convolutions capture information at different levels and granularities.
As Inception modules keep concatenating their outputs, the number of channels in the feature maps grows and grows. To keep the channel count, and with it the computation and parameter count, from exploding, a 1x1 convolution is added before the 3x3 and 5x5 convolutions to reduce dimensionality, control the number of output feature maps, and cut parameters and computation. The left figure shows the Inception module, the right figure the BottleNeck module.
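For reference, here is a minimal Keras sketch of such a bottleneck Inception module (my own illustration; the branch widths 64/96/128/16/32/32 are example values, not something this post specifies).

import tensorflow as tf
from tensorflow.keras import layers

def inception_bottleneck(x, f1, f3_reduce, f3, f5_reduce, f5, f_pool):
    # Branch 1: plain 1x1 convolution
    b1 = layers.Conv2D(f1, 1, padding='same', activation='relu')(x)
    # Branch 2: 1x1 reduction, then 3x3 convolution
    b2 = layers.Conv2D(f3_reduce, 1, padding='same', activation='relu')(x)
    b2 = layers.Conv2D(f3, 3, padding='same', activation='relu')(b2)
    # Branch 3: 1x1 reduction, then 5x5 convolution
    b3 = layers.Conv2D(f5_reduce, 1, padding='same', activation='relu')(x)
    b3 = layers.Conv2D(f5, 5, padding='same', activation='relu')(b3)
    # Branch 4: 3x3 max pooling, then 1x1 convolution
    b4 = layers.MaxPooling2D(3, strides=1, padding='same')(x)
    b4 = layers.Conv2D(f_pool, 1, padding='same', activation='relu')(b4)
    # Concatenate the four branches along the channel axis
    return layers.Concatenate()([b1, b2, b3, b4])

inputs = tf.keras.Input(shape=(28, 28, 192))
outputs = inception_bottleneck(inputs, 64, 96, 128, 16, 32, 32)
print(tf.keras.Model(inputs, outputs).output_shape)  # (None, 28, 28, 256)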
(1)First, InceptionV3 improves the BottleNeck module by factorizing the 5x5 convolution into two 3x3 convolutions. Two stacked 3x3 convolutions give the same receptive field as one 5x5 convolution while using fewer parameters and adding an extra non-linearity, which improves the model's expressive power.
(2)Replace the 1x1 convolution after the pooling layer with a 3x3 convolution.
(3)Use only 1x1 convolutions in the first layer and only 3x3 convolutions in the second layer.
(4)After the image comes in, it first passes through a single 1x1 convolution to produce a feature map, and the three branches then all operate on that feature map.
(5)After the input, apply a grouped convolution to the feature map produced by the 1x1 convolution: different kernels process different channels, and the groups are independent of one another.
(6)The Xception module applies the depthwise separable convolution idea: first a pointwise convolution, then a depthwise convolution in which each 3x3 kernel processes only one channel. The order of the pointwise and depthwise convolutions has little practical impact. A sketch of this module is given right after this list.
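As a concrete illustration of step (6), here is a minimal sketch of this "extreme Inception" module (my own code, not the author's): a shared 1x1 pointwise convolution followed by a 3x3 depthwise convolution in which every kernel sees exactly one channel. Keras's layers.SeparableConv2D, used later in this post, bundles the same two steps but applies the depthwise step first; as noted above, the order makes little practical difference.

import tensorflow as tf
from tensorflow.keras import layers

def extreme_inception(x, filters):
    # Pointwise 1x1 convolution mixes information across channels
    x = layers.Conv2D(filters, kernel_size=1, padding='same', use_bias=False)(x)
    # Depthwise 3x3 convolution: each kernel processes a single channel
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same', use_bias=False)(x)
    return x

inputs = tf.keras.Input(shape=(19, 19, 728))
outputs = extreme_inception(inputs, filters=728)
print(tf.keras.Model(inputs, outputs).output_shape)  # (None, 19, 19, 728)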
The Xception network architecture given in the paper is shown in the figure below.
(1)Standard convolution block
A standard convolution block consists of a convolution, batch normalization, and an activation function.
#(1)Standard convolution block
def conv_block(input_tensor, filters, kernel_size, stride):
    # Convolution + batch normalization + activation
    x = layers.Conv2D(filters = filters,               # number of output feature maps
                      kernel_size = kernel_size,       # kernel size
                      strides = stride,                # stride
                      padding = 'same',                # stride=1 keeps the spatial size, stride=2 halves it
                      use_bias = False)(input_tensor)  # no bias needed when a BN layer follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU()(x)                # ReLU activation
    return x  # output feature map of the standard convolution block
(2)Residual block
Following the architecture diagram, a residual unit is built from two depthwise separable convolution blocks, a max pooling layer, and a residual (shortcut) branch.
#(2)Depthwise separable convolution block
def sep_conv_block(input_tensor, filters, kernel_size):
    # Activation first
    x = layers.ReLU()(input_tensor)
    # SeparableConv2D bundles the depthwise and pointwise convolutions
    x = layers.SeparableConv2D(filters = filters,          # number of pointwise kernels = number of output feature maps
                               kernel_size = kernel_size,  # kernel size of the depthwise convolution
                               strides = 1,                # stride of the depthwise convolution
                               padding = 'same',           # keep the spatial size unchanged
                               use_bias = False)(x)        # no bias needed when a BN layer follows
    return x  # output feature map
#(3)One residual unit
def res_block(input_tensor, filters):
    # ① Residual (shortcut) branch
    residual = layers.Conv2D(filters,                    # number of output channels
                             kernel_size = (1,1),        # kernel size
                             strides = 2)(input_tensor)  # downsample so the shortcut matches the main branch's output size
    residual = layers.BatchNormalization()(residual)  # batch normalization
    # ② Main branch: two separable convolution blocks + max pooling
    x = sep_conv_block(input_tensor, filters, kernel_size=(3,3))
    x = sep_conv_block(x, filters, kernel_size=(3,3))
    x = layers.MaxPooling2D(pool_size=(3,3), strides=2, padding='same')(x)
    # ③ Add the shortcut to the main branch (residual connection)
    output = layers.Add()([residual, x])
    return output
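Before assembling the whole network, a quick sanity check of res_block can be helpful. This is my own addition (it assumes the imports from the complete code below and the helper functions above are already in place): pass a dummy input through the block and confirm that the spatial size is halved and the channel count matches filters.

inputs = tf.keras.Input(shape=(75, 75, 128))         # example: the feature map entering the first residual unit
outputs = res_block(inputs, filters=256)
print(tf.keras.Model(inputs, outputs).output_shape)  # (None, 38, 38, 256)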
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, layers
#(1)Standard convolution block
def conv_block(input_tensor, filters, kernel_size, stride):
    # Convolution + batch normalization + activation
    x = layers.Conv2D(filters = filters,               # number of output feature maps
                      kernel_size = kernel_size,       # kernel size
                      strides = stride,                # stride
                      padding = 'same',                # stride=1 keeps the spatial size, stride=2 halves it
                      use_bias = False)(input_tensor)  # no bias needed when a BN layer follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU()(x)                # ReLU activation
    return x  # output feature map of the standard convolution block
#(2)Depthwise separable convolution block
def sep_conv_block(input_tensor, filters, kernel_size):
    # Activation first
    x = layers.ReLU()(input_tensor)
    # SeparableConv2D bundles the depthwise and pointwise convolutions
    x = layers.SeparableConv2D(filters = filters,          # number of pointwise kernels = number of output feature maps
                               kernel_size = kernel_size,  # kernel size of the depthwise convolution
                               strides = 1,                # stride of the depthwise convolution
                               padding = 'same',           # keep the spatial size unchanged
                               use_bias = False)(x)        # no bias needed when a BN layer follows
    return x  # output feature map
#(3)One residual unit
def res_block(input_tensor, filters):
    # ① Residual (shortcut) branch
    residual = layers.Conv2D(filters,                    # number of output channels
                             kernel_size = (1,1),        # kernel size
                             strides = 2)(input_tensor)  # downsample so the shortcut matches the main branch's output size
    residual = layers.BatchNormalization()(residual)  # batch normalization
    # ② Main branch: two separable convolution blocks + max pooling
    x = sep_conv_block(input_tensor, filters, kernel_size=(3,3))
    x = sep_conv_block(x, filters, kernel_size=(3,3))
    x = layers.MaxPooling2D(pool_size=(3,3), strides=2, padding='same')(x)
    # ③ Add the shortcut to the main branch (residual connection)
    output = layers.Add()([residual, x])
    return output
#(4)Middle Flow module
def middle_flow(x, filters):
    # This module is repeated 8 times
    for _ in range(8):
        # Residual branch
        residual = x
        # Three depthwise separable convolution blocks
        x = sep_conv_block(x, filters, kernel_size=(3,3))
        x = sep_conv_block(x, filters, kernel_size=(3,3))
        x = sep_conv_block(x, filters, kernel_size=(3,3))
        # Add the residual branch
        x = layers.Add()([residual, x])
    return x
#(5)Backbone network
def xception(input_shape, classes):
    # Build the input
    inputs = keras.Input(shape=input_shape)
    # [299,299,3]==>[150,150,32]
    x = conv_block(inputs, filters=32, kernel_size=(3,3), stride=2)  # standard convolution block
    # [150,150,32]==>[150,150,64]
    x = conv_block(x, filters=64, kernel_size=(3,3), stride=1)
    # [150,150,64]==>[75,75,128]
    # Residual branch
    residual = layers.Conv2D(filters=128, kernel_size=(1,1), strides=2,
                             padding='same', use_bias=False)(x)
    residual = layers.BatchNormalization()(residual)
    # Convolution block [150,150,64]==>[150,150,128]
    # (no ReLU before the very first separable convolution)
    x = layers.SeparableConv2D(128, kernel_size=(3,3), strides=1, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    # [150,150,128]==>[150,150,128]
    x = sep_conv_block(x, filters=128, kernel_size=(3,3))
    # [150,150,128]==>[75,75,128]
    x = layers.MaxPooling2D(pool_size=(3,3), strides=2, padding='same')(x)
    # Add the residual branch
    x = layers.Add()([residual, x])
    # [75,75,128]==>[38,38,256]
    x = res_block(x, filters=256)
    # [38,38,256]==>[19,19,728]
    x = res_block(x, filters=728)
    # [19,19,728]==>[19,19,728]
    x = middle_flow(x, filters=728)
    # Residual branch [19,19,728]==>[10,10,1024]
    residual = layers.Conv2D(filters=1024, kernel_size=(1,1),
                             strides=2, use_bias=False, padding='same')(x)
    residual = layers.BatchNormalization()(residual)  # batch normalization
    # Convolution block [19,19,728]==>[19,19,728]
    x = sep_conv_block(x, filters=728, kernel_size=(3,3))
    # [19,19,728]==>[19,19,1024]
    x = sep_conv_block(x, filters=1024, kernel_size=(3,3))
    # [19,19,1024]==>[10,10,1024]
    x = layers.MaxPooling2D(pool_size=(3,3), strides=2, padding='same')(x)
    # Add the residual branch [10,10,1024]
    x = layers.Add()([residual, x])
    # [10,10,1024]==>[10,10,1536]
    x = layers.SeparableConv2D(1536, (3,3), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # [10,10,1536]==>[10,10,2048]
    x = layers.SeparableConv2D(2048, (3,3), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # [10,10,2048]==>[None,2048]
    x = layers.GlobalAveragePooling2D()(x)
    # [None,2048]==>[None,classes]
    outputs = layers.Dense(classes)(x)  # logits layer; softmax is not applied here
    # Build the model
    model = Model(inputs, outputs)
    return model

#(6)Instantiate the model
if __name__ == '__main__':
    model = xception(input_shape=[299,299,3], classes=1000)
    model.summary()  # print the network architecture
Inspecting the architecture with model.summary(), the network has a little over 22 million parameters.
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 299, 299, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 150, 150, 32) 864 input_1[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 150, 150, 32) 128 conv2d[0][0]
__________________________________________________________________________________________________
re_lu (ReLU) (None, 150, 150, 32) 0 batch_normalization[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 150, 150, 64) 18432 re_lu[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 150, 150, 64) 256 conv2d_1[0][0]
__________________________________________________________________________________________________
re_lu_1 (ReLU) (None, 150, 150, 64) 0 batch_normalization_1[0][0]
__________________________________________________________________________________________________
separable_conv2d (SeparableConv (None, 150, 150, 128 8768 re_lu_1[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 150, 150, 128 512 separable_conv2d[0][0]
__________________________________________________________________________________________________
re_lu_2 (ReLU) (None, 150, 150, 128 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
separable_conv2d_1 (SeparableCo (None, 150, 150, 128 17536 re_lu_2[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 75, 75, 128) 0 separable_conv2d_1[0][0]
__________________________________________________________________________________________________
re_lu_3 (ReLU) (None, 75, 75, 128) 0 max_pooling2d[0][0]
__________________________________________________________________________________________________
separable_conv2d_2 (SeparableCo (None, 75, 75, 256) 33920 re_lu_3[0][0]
__________________________________________________________________________________________________
re_lu_4 (ReLU) (None, 75, 75, 256) 0 separable_conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 38, 38, 256) 33024 max_pooling2d[0][0]
__________________________________________________________________________________________________
separable_conv2d_3 (SeparableCo (None, 75, 75, 256) 67840 re_lu_4[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 38, 38, 256) 1024 conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 38, 38, 256) 0 separable_conv2d_3[0][0]
__________________________________________________________________________________________________
add (Add) (None, 38, 38, 256) 0 batch_normalization_4[0][0]
max_pooling2d_1[0][0]
__________________________________________________________________________________________________
re_lu_5 (ReLU) (None, 38, 38, 256) 0 add[0][0]
__________________________________________________________________________________________________
separable_conv2d_4 (SeparableCo (None, 38, 38, 728) 188672 re_lu_5[0][0]
__________________________________________________________________________________________________
re_lu_6 (ReLU) (None, 38, 38, 728) 0 separable_conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 19, 19, 728) 187096 add[0][0]
__________________________________________________________________________________________________
separable_conv2d_5 (SeparableCo (None, 38, 38, 728) 536536 re_lu_6[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 19, 19, 728) 2912 conv2d_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 19, 19, 728) 0 separable_conv2d_5[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 19, 19, 728) 0 batch_normalization_5[0][0]
max_pooling2d_2[0][0]
__________________________________________________________________________________________________
re_lu_7 (ReLU) (None, 19, 19, 728) 0 add_1[0][0]
__________________________________________________________________________________________________
separable_conv2d_6 (SeparableCo (None, 19, 19, 728) 536536 re_lu_7[0][0]
__________________________________________________________________________________________________
re_lu_8 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_6[0][0]
__________________________________________________________________________________________________
separable_conv2d_7 (SeparableCo (None, 19, 19, 728) 536536 re_lu_8[0][0]
__________________________________________________________________________________________________
re_lu_9 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_7[0][0]
__________________________________________________________________________________________________
separable_conv2d_8 (SeparableCo (None, 19, 19, 728) 536536 re_lu_9[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 19, 19, 728) 0 add_1[0][0]
separable_conv2d_8[0][0]
__________________________________________________________________________________________________
re_lu_10 (ReLU) (None, 19, 19, 728) 0 add_2[0][0]
__________________________________________________________________________________________________
separable_conv2d_9 (SeparableCo (None, 19, 19, 728) 536536 re_lu_10[0][0]
__________________________________________________________________________________________________
re_lu_11 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_9[0][0]
__________________________________________________________________________________________________
separable_conv2d_10 (SeparableC (None, 19, 19, 728) 536536 re_lu_11[0][0]
__________________________________________________________________________________________________
re_lu_12 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_10[0][0]
__________________________________________________________________________________________________
separable_conv2d_11 (SeparableC (None, 19, 19, 728) 536536 re_lu_12[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 19, 19, 728) 0 add_2[0][0]
separable_conv2d_11[0][0]
__________________________________________________________________________________________________
re_lu_13 (ReLU) (None, 19, 19, 728) 0 add_3[0][0]
__________________________________________________________________________________________________
separable_conv2d_12 (SeparableC (None, 19, 19, 728) 536536 re_lu_13[0][0]
__________________________________________________________________________________________________
re_lu_14 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_12[0][0]
__________________________________________________________________________________________________
separable_conv2d_13 (SeparableC (None, 19, 19, 728) 536536 re_lu_14[0][0]
__________________________________________________________________________________________________
re_lu_15 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_13[0][0]
__________________________________________________________________________________________________
separable_conv2d_14 (SeparableC (None, 19, 19, 728) 536536 re_lu_15[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 19, 19, 728) 0 add_3[0][0]
separable_conv2d_14[0][0]
__________________________________________________________________________________________________
re_lu_16 (ReLU) (None, 19, 19, 728) 0 add_4[0][0]
__________________________________________________________________________________________________
separable_conv2d_15 (SeparableC (None, 19, 19, 728) 536536 re_lu_16[0][0]
__________________________________________________________________________________________________
re_lu_17 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_15[0][0]
__________________________________________________________________________________________________
separable_conv2d_16 (SeparableC (None, 19, 19, 728) 536536 re_lu_17[0][0]
__________________________________________________________________________________________________
re_lu_18 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_16[0][0]
__________________________________________________________________________________________________
separable_conv2d_17 (SeparableC (None, 19, 19, 728) 536536 re_lu_18[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 19, 19, 728) 0 add_4[0][0]
separable_conv2d_17[0][0]
__________________________________________________________________________________________________
re_lu_19 (ReLU) (None, 19, 19, 728) 0 add_5[0][0]
__________________________________________________________________________________________________
separable_conv2d_18 (SeparableC (None, 19, 19, 728) 536536 re_lu_19[0][0]
__________________________________________________________________________________________________
re_lu_20 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_18[0][0]
__________________________________________________________________________________________________
separable_conv2d_19 (SeparableC (None, 19, 19, 728) 536536 re_lu_20[0][0]
__________________________________________________________________________________________________
re_lu_21 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_19[0][0]
__________________________________________________________________________________________________
separable_conv2d_20 (SeparableC (None, 19, 19, 728) 536536 re_lu_21[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 19, 19, 728) 0 add_5[0][0]
separable_conv2d_20[0][0]
__________________________________________________________________________________________________
re_lu_22 (ReLU) (None, 19, 19, 728) 0 add_6[0][0]
__________________________________________________________________________________________________
separable_conv2d_21 (SeparableC (None, 19, 19, 728) 536536 re_lu_22[0][0]
__________________________________________________________________________________________________
re_lu_23 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_21[0][0]
__________________________________________________________________________________________________
separable_conv2d_22 (SeparableC (None, 19, 19, 728) 536536 re_lu_23[0][0]
__________________________________________________________________________________________________
re_lu_24 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_22[0][0]
__________________________________________________________________________________________________
separable_conv2d_23 (SeparableC (None, 19, 19, 728) 536536 re_lu_24[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 19, 19, 728) 0 add_6[0][0]
separable_conv2d_23[0][0]
__________________________________________________________________________________________________
re_lu_25 (ReLU) (None, 19, 19, 728) 0 add_7[0][0]
__________________________________________________________________________________________________
separable_conv2d_24 (SeparableC (None, 19, 19, 728) 536536 re_lu_25[0][0]
__________________________________________________________________________________________________
re_lu_26 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_24[0][0]
__________________________________________________________________________________________________
separable_conv2d_25 (SeparableC (None, 19, 19, 728) 536536 re_lu_26[0][0]
__________________________________________________________________________________________________
re_lu_27 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_25[0][0]
__________________________________________________________________________________________________
separable_conv2d_26 (SeparableC (None, 19, 19, 728) 536536 re_lu_27[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 19, 19, 728) 0 add_7[0][0]
separable_conv2d_26[0][0]
__________________________________________________________________________________________________
re_lu_28 (ReLU) (None, 19, 19, 728) 0 add_8[0][0]
__________________________________________________________________________________________________
separable_conv2d_27 (SeparableC (None, 19, 19, 728) 536536 re_lu_28[0][0]
__________________________________________________________________________________________________
re_lu_29 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_27[0][0]
__________________________________________________________________________________________________
separable_conv2d_28 (SeparableC (None, 19, 19, 728) 536536 re_lu_29[0][0]
__________________________________________________________________________________________________
re_lu_30 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_28[0][0]
__________________________________________________________________________________________________
separable_conv2d_29 (SeparableC (None, 19, 19, 728) 536536 re_lu_30[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 19, 19, 728) 0 add_8[0][0]
separable_conv2d_29[0][0]
__________________________________________________________________________________________________
re_lu_31 (ReLU) (None, 19, 19, 728) 0 add_9[0][0]
__________________________________________________________________________________________________
separable_conv2d_30 (SeparableC (None, 19, 19, 728) 536536 re_lu_31[0][0]
__________________________________________________________________________________________________
re_lu_32 (ReLU) (None, 19, 19, 728) 0 separable_conv2d_30[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 10, 10, 1024) 745472 add_9[0][0]
__________________________________________________________________________________________________
separable_conv2d_31 (SeparableC (None, 19, 19, 1024) 752024 re_lu_32[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 10, 10, 1024) 4096 conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 10, 10, 1024) 0 separable_conv2d_31[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 10, 10, 1024) 0 batch_normalization_6[0][0]
max_pooling2d_3[0][0]
__________________________________________________________________________________________________
separable_conv2d_32 (SeparableC (None, 10, 10, 1536) 1582080 add_10[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 10, 10, 1536) 6144 separable_conv2d_32[0][0]
__________________________________________________________________________________________________
re_lu_33 (ReLU) (None, 10, 10, 1536) 0 batch_normalization_7[0][0]
__________________________________________________________________________________________________
separable_conv2d_33 (SeparableC (None, 10, 10, 2048) 3159552 re_lu_33[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 10, 10, 2048) 8192 separable_conv2d_33[0][0]
__________________________________________________________________________________________________
re_lu_34 (ReLU) (None, 10, 10, 2048) 0 batch_normalization_8[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 2048) 0 re_lu_34[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 1000) 2049000 global_average_pooling2d[0][0]
==================================================================================================
Total params: 22,817,480
Trainable params: 22,805,848
Non-trainable params: 11,632
__________________________________________________________________________________________________
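If you want to go one step further and actually train the model, a minimal sketch could look like the following. This is my own example: the data here is random dummy data purely to demonstrate the API, and from_logits=True matches the fact that the final Dense layer above outputs raw logits without a softmax.

model = xception(input_shape=[299, 299, 3], classes=1000)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Dummy data just to verify that a training step runs; replace with a real dataset
x_train = tf.random.normal([8, 299, 299, 3])
y_train = tf.random.uniform([8], maxval=1000, dtype=tf.int32)
model.fit(x_train, y_train, batch_size=4, epochs=1)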