tensorflow2.2实现ResNeXt

1. 分组卷积

    分组卷积(Group Convolution)最早出现在AlexNet中。 受限于当时的硬件资源,在AlexNet网络训练时,难以把整个网络全部放在一个GPU中进行训练,因此,作者将卷积运算分给多个GPU分别进行计算,最终把多个GPU的结果进行融合。 因此分组卷积的概念应运而生。
    分组卷积简单来说就是将每层的特征图数量分为不同的组,然后对不同组的特征图进行卷积操作。
    分组卷积的优点:帮助模型减少了计算量和权值参数
    如下图是论文《Aggregated Residual Transformations for Deep Neural Networks》中的实验结果。

  • params代表参数量
  • FLOPs代表计算量
    tensorflow2.2实现ResNeXt_第1张图片

2. ResNeXt中的分组卷积

    如下图右边是论文中使用的残差结构。
tensorflow2.2实现ResNeXt_第2张图片
其中:

  • 256-d in代表的是输入256个特征图
  • 256-d out代表的是输出256个特征图
  • total 32 paths代表的是一个有32个通道
        左边的图是普通的残差结构,右边是加了分组卷积的残差结构。
  1. 第一层输入256个特征图,输出了128个特征图,然后进行分组,32组,每组4个特征图。
  2. 第二层对每组分别进行卷积计算,卷积之后,每组输出4个特征图。
  3. 第三层对4个特征图进行卷积,输出256个特征图。
  4. 第四层对32组卷积后的特征图进行堆叠。

3. ResNeXt的网络结构

如下图
tensorflow2.2实现ResNeXt_第3张图片
其中:

  • stage表示阶段
  • output表示输出
  • stride表示步长

4. 实现代码

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Dense, ZeroPadding2D, Conv2D, MaxPool2D, 
                                     GlobalAvgPool2D, Input, BatchNormalization,
                                     Activation, Add, Lambda, concatenate)
from tensorflow.keras.models import Model
from plot_model import plot_model


# ----------------------- #
#   groups代表多少组
#   g_channels代表每组的特征图数量
# ----------------------- #
def group_conv2_block(x_0, strides, groups, g_channels):
    g_list = []
    for i in range(groups):
        x = Lambda(lambda x: x[:, :, :, i*g_channels: (i+1)*g_channels])(x_0)
        x = Conv2D(filters=g_channels, kernel_size=3, strides=strides, padding='same', use_bias=False)(x)
        g_list.append(x)
    x = concatenate(g_list, axis=3)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)
    return x

# 结构快
def block(x, filters, strides=1, groups=32, conv_short=True):
    if conv_short:
        short_cut = Conv2D(filters=filters*2, kernel_size=1, strides=strides, padding='same')(x)
        short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
    else:
        short_cut = x

    
    # 三层卷积
    x = Conv2D(filters=filters, kernel_size=1, strides=1, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)
    
    g_channels = int(filters/groups)
    x = group_conv2_block(x, strides=strides, groups=groups, g_channels=g_channels)

    x = Conv2D(filters=filters*2, kernel_size=1, strides=1, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)

    x = Add()([x, short_cut])
    x = Activation('relu')(x)

    return x

    
def Resnext(inputs, classes):
    x = ZeroPadding2D((3, 3))(inputs)
    x = Conv2D(filters=64, kernel_size=7, strides=2, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D((1, 1))(x)
    x = MaxPool2D(pool_size=3, strides=2, padding='valid')(x)

    x = block(x, filters=128, strides=1, conv_short=True)
    x = block(x, filters=128, conv_short=False)
    x = block(x, filters=128, conv_short=False)
    
    x = block(x, filters=256, strides=2, conv_short=True)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)
    
    x = block(x, filters=512, strides=2, conv_short=True)
    x = block(x, filters=512, conv_short=False)
    x = block(x, filters=512, conv_short=False)
    x = block(x, filters=512, conv_short=False)
    x = block(x, filters=512, conv_short=False)
    x = block(x, filters=512, conv_short=False)

    x = block(x, filters=1024, strides=2, conv_short=True)
    x = block(x, filters=1024, conv_short=False)
    x = block(x, filters=1024, conv_short=False)

    x = GlobalAvgPool2D()(x)
    x = Dense(classes, activation='softmax')(x)
    
    return x

if __name__ == '__main__':

    is_show_picture = False
    inputs = Input(shape=(224,224,3))
    classes = 17
    model = Model(inputs=inputs, outputs=Resnext(inputs, classes))
    model.summary()
    for i in range(len(model.layers)):
        print(i, model.layers[i])
    if is_show_picture:
        plot_model(model,
           to_file='./nets_picture/Resnext.png',
           )
        print("plot_model------------------------>")
    
    
    

你可能感兴趣的:(深度学习,深度学习,cnn,计算机视觉)