Implementing RepVGG in TensorFlow 2.4

Preface

RepVGG is a novel CNN design paradigm proposed by Tsinghua University, MEGVII Technology, and collaborators. It avoids the low accuracy that plain VGG-style training usually yields, while keeping the efficient-inference advantage of the VGG design.

论文地址: https://arxiv.org/abs/2101.03697

This post highlights the two aspects of the method that I consider most important.

Training-time multi-branch structure

A ResBlock in ResNet computes $y = x + f(x)$. Multi-branch structures are unfriendly to inference but friendly to training, so the authors designed RepVGG to be multi-branch at training time and single-branch at inference time. Borrowing ResNet's identity and $1\times1$ branches, they designed a module of the following form (a numeric sketch of how the branches fuse into one convolution follows the list below):

$$y = x + g(x) + f(x)$$

where:

  • $g(x)$ is a $1\times1$ convolution.
  • $f(x)$ is a $3\times3$ convolution.
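
Because convolution is linear, these three branches can be fused after training into a single $3\times3$ convolution: zero-pad the $1\times1$ kernel to $3\times3$, write the identity map as a $3\times3$ kernel with a 1 at the center of each channel, and sum the three kernels. Here is a minimal numeric sketch of that equivalence (plain tf.nn ops, BatchNorm omitted for brevity; all variable names are illustrative):

import numpy as np
import tensorflow as tf

# Random input and branch kernels (channels-last, 2 channels for brevity).
x  = tf.random.normal((1, 8, 8, 2))
k3 = tf.random.normal((3, 3, 2, 2))   # 3x3 branch f(x)
k1 = tf.random.normal((1, 1, 2, 2))   # 1x1 branch g(x)

# Multi-branch output: y = x + g(x) + f(x)
y_multi = (x
           + tf.nn.conv2d(x, k1, strides=1, padding='SAME')
           + tf.nn.conv2d(x, k3, strides=1, padding='SAME'))

# Fuse: zero-pad the 1x1 kernel to 3x3, express the identity as a 3x3
# kernel with a 1 at the center of each channel, then sum the kernels.
k1_as_3 = tf.pad(k1, [[1, 1], [1, 1], [0, 0], [0, 0]])
k_id = np.zeros((3, 3, 2, 2), np.float32)
for c in range(2):
    k_id[1, 1, c, c] = 1.0
k_fused = k3 + k1_as_3 + k_id

y_single = tf.nn.conv2d(x, k_fused, strides=1, padding='SAME')
print(np.allclose(y_multi.numpy(), y_single.numpy(), atol=1e-4))  # True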

Simple is fast, memory-economical, flexible

  • Fast: Existing multi-branch architectures have lower theoretical FLOPs than VGG, yet they are not faster at inference. For example, VGG-16 has 8.4× the FLOPs of EfficientNet-B3 but runs 1.8× faster on a 1080Ti, which means the former's computational density is about 15× the latter's. The mismatch between FLOPs and inference speed stems mainly from two factors: (1) MAC (memory access cost): for example, the Add and Concat in multi-branch structures involve little computation but incur high MAC; (2) degree of parallelism: prior work has shown that models with high parallelism run faster at inference than those with low parallelism.
  • Memory-economical: A multi-branch topology is memory-inefficient because the output of every branch must be kept around until the Add/Concat, which raises the peak memory footprint; a plain model is more memory-efficient.
  • Flexible: A multi-branch topology limits the CNN's flexibility: a ResBlock, for instance, forces the tensors of its two branches to have the same shape. Multi-branch structures are also unfriendly to channel pruning.

Network architecture

[Figure 1]

[Figure 2]

Code implementation

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, GlobalAvgPool2D, Activation, Multiply,
    Add, Dense, Input
)

# ----------------- #
# Convolution + BN
# ----------------- #
def conv_bn(filters, kernel_size, strides, padding, groups=1):
    def _conv_bn(x):
        x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                   padding=padding, groups=groups, use_bias=False)(x)
        x = BatchNormalization()(x)
        return x
    return _conv_bn

# ----------------- #
# SE (squeeze-and-excitation) block
# ----------------- #
def SE_block(x_0, r=16):
    channels = x_0.shape[-1]
    x = GlobalAvgPool2D()(x_0)
    # (?, C) -> (?, 1, 1, C)
    x = x[:, None, None, :]
    # Two 1x1 convolutions stand in for the fully-connected layers
    x = Conv2D(filters=channels//r, kernel_size=1, strides=1)(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=channels, kernel_size=1, strides=1)(x)
    x = Activation('sigmoid')(x)
    # Rescale the input channel-wise by the learned attention weights
    x = Multiply()([x_0, x])

    return x

# ----------------- #
# RepVGG block
# ----------------- #
def RepVGGBlock(filters, kernel_size, strides=1, padding='valid', dilation=1, groups=1, deploy=False, use_se=False):
    def _RepVGGBlock(inputs):
        # Deploy mode: the branches have been re-parameterized into a single
        # conv with bias, so only one conv + (optional SE) + ReLU remains.
        if deploy:
            x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                       padding=padding, dilation_rate=dilation, groups=groups, use_bias=True)(inputs)
            if use_se:
                x = SE_block(x)
            return Activation('relu')(x)

        # Training mode: a 3x3 conv-BN branch plus a 1x1 conv-BN branch
        # (for a 1x1 kernel, 'valid' and 'same' padding give the same output),
        # plus a BN-only identity branch when the shapes allow it.
        branches = [
            conv_bn(filters=filters, kernel_size=kernel_size, strides=strides,
                    padding=padding, groups=groups)(inputs),
            conv_bn(filters=filters, kernel_size=1, strides=strides,
                    padding='valid', groups=groups)(inputs),
        ]
        if inputs.shape[-1] == filters and strides == 1:
            branches.append(BatchNormalization()(inputs))
        x = Add()(branches)
        if use_se:
            x = SE_block(x)
        return Activation('relu')(x)

    return _RepVGGBlock
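
At deploy time, the trained conv-BN branches must be converted into the single biased convolution that the deploy=True path above expects. This post's code does not include that conversion, so below is a minimal sketch of the fusion arithmetic, mirroring the get_equivalent_kernel_bias logic of the official PyTorch repo. The helper names, and how you extract the weights from your trained model, are assumptions rather than part of the code above:

import numpy as np

def fuse_conv_bn(kernel, gamma, beta, mean, var, eps=1e-3):
    # Fold a BatchNormalization layer into the preceding conv:
    # BN(W*x) = (gamma/std) * (W*x) + (beta - mean*gamma/std)
    std = np.sqrt(var + eps)   # eps: Keras BatchNormalization default is 1e-3
    return kernel * (gamma / std), beta - mean * gamma / std

def merge_branches(k3, b3, k1, b1, k_id, b_id):
    # Zero-pad the fused 1x1 kernel to 3x3, then sum all three branches
    # into the single kernel/bias pair used by the deploy-mode Conv2D.
    k1_as_3 = np.pad(k1, [(1, 1), (1, 1), (0, 0), (0, 0)])
    return k3 + k1_as_3 + k_id, b3 + b1 + b_id

The identity branch (BN only) is fused by treating it as a 3x3 convolution whose kernel has a 1 at the center of each channel, exactly as in the numeric sketch after the equation above.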

# ----------------- #
# Stacking RepVGG blocks into a stage
# ----------------- #
def make_stage(planes, num_blocks, stride_1, deploy, use_se, override_groups_map=None, layer_id_start=1):
    def _make_stage(x):
        override_groups = override_groups_map or {}
        # The first block of a stage downsamples; the rest keep stride 1
        strides = [stride_1] + [1] * (num_blocks - 1)
        for i, stride in enumerate(strides):
            cur_groups = override_groups.get(layer_id_start + i, 1)
            x = RepVGGBlock(filters=planes, kernel_size=3, strides=stride, padding='same',
                            groups=cur_groups, deploy=deploy, use_se=use_se)(x)
        return x
    return _make_stage

# ----------------- #
# The RepVGG network
# ----------------- #
def RepVGG(x, num_blocks, classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False):
    override_groups_map = override_groups_map or dict()
    in_planes = min(64, int(64 * width_multiplier[0]))
    # Stem: a single downsampling RepVGG block
    out = RepVGGBlock(filters=in_planes, kernel_size=3, strides=2, padding='same', deploy=deploy, use_se=use_se)(x)
    # Four stages; block indices continue across stages so that the keys of
    # override_groups_map refer to global block positions, as in the official code
    stage_planes = [int(64 * width_multiplier[0]), int(128 * width_multiplier[1]),
                    int(256 * width_multiplier[2]), int(512 * width_multiplier[3])]
    cur_layer_id = 1
    for planes, blocks in zip(stage_planes, num_blocks):
        out = make_stage(planes, blocks, stride_1=2, deploy=deploy, use_se=use_se,
                         override_groups_map=override_groups_map, layer_id_start=cur_layer_id)(out)
        cur_layer_id += blocks
    out = GlobalAvgPool2D()(out)
    out = Dense(classes)(out)   # logits; no softmax here
    return out
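
Since the head is a plain Dense(classes) with no softmax, the model outputs logits. A usage sketch pairing it with a from_logits loss (the optimizer and A0 configuration here are just placeholders):

import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

inputs = Input(shape=(224, 224, 3))
logits = RepVGG(inputs, num_blocks=[2, 4, 14, 1], classes=1000,
                width_multiplier=[0.75, 0.75, 0.75, 2.5])  # A0 configuration
model = Model(inputs, logits)
model.compile(optimizer='sgd',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])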

optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
g2_map = {l: 2 for l in optional_groupwise_layers}
g4_map = {l: 4 for l in optional_groupwise_layers}

def create_RepVGG_A0(x, classes=1000, deploy=False):
    return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=classes,
                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy)

def create_RepVGG_A1(x, deploy=False):
    return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=1000,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
    
def create_RepVGG_A2(x, deploy=False):
    return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=1000,
                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy)

def create_RepVGG_B0(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)

def create_RepVGG_B1(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy)

def create_RepVGG_B1g2(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy)

def create_RepVGG_B1g4(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy)

def create_RepVGG_B2(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy)

def create_RepVGG_B2g2(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy)

def create_RepVGG_B2g4(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy)

def create_RepVGG_B3(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy)

def create_RepVGG_B3g2(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy)

def create_RepVGG_B3g4(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy)

def create_RepVGG_D2se(x, deploy=False):
    return RepVGG(x, num_blocks=[8, 14, 24, 1], classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True)

if __name__ == '__main__':
    inputs = Input(shape=(224, 224, 3))
    model = Model(inputs=inputs, outputs=create_RepVGG_A0(inputs, classes=1000))
    model.summary()

Corrections welcome!
