RepVGG是清华大学与旷视科技等提出的一种新颖的CNN设计范式,它既避免了VGG类网络训练所得精度低的问题,又保留了VGG式结构推理高效的优点。
论文地址: https://arxiv.org/abs/2101.03697
本文挑出其中我认为最重要的两点方法进行介绍。
ResNet中的ResBlock采用 $y = x + f(x)$ 的形式。多分支结构虽然对推理不友好,但对训练友好,因此作者将RepVGG设计为训练时多分支、推理时单分支的结构。作者参考ResNet的identity分支与 $1\times1$ 分支,设计了如下形式的模块:
$$y = x + g(x) + f(x)$$
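其中,$x$ 为identity分支(只含一个BN,且仅在输入输出通道数相同、步长为1时存在),$g(x)$ 为 $1\times1$ 卷积+BN分支,$f(x)$ 为 $3\times3$ 卷积+BN分支;三个分支相加后再经过ReLU。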
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
Conv2D, BatchNormalization, GlobalAvgPool2D, Activation, Multiply,
Add, Dense, Input
)
# ----------------- #
# 卷积+标准化
# ----------------- #
def conv_bn(filters, kernel_size, strides, padding, groups=1):
    """Return a callable that applies a bias-free Conv2D followed by BatchNormalization.

    The convolution is created with ``use_bias=False``; the BatchNormalization
    layer that follows provides the learned offset.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        strides: convolution stride.
        padding: Keras padding mode (e.g. 'same' or 'valid').
        groups: number of convolution groups.

    Returns:
        A function mapping an input tensor to the normalized conv output.
    """
    def _conv_bn(x):
        conv = Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            groups=groups,
            use_bias=False,
        )
        y = conv(x)
        return BatchNormalization()(y)
    return _conv_bn
# ----------------- #
# SE模块
# ----------------- #
def SE_block(x_0, r=16):
    """Squeeze-and-Excitation: rescale each channel of ``x_0`` by a learned gate.

    The per-channel descriptor comes from global average pooling. It passes
    through a bottleneck of two 1x1 convolutions, used in place of fully
    connected layers, with ReLU and then sigmoid. The resulting gate is
    multiplied back onto ``x_0``.

    Args:
        x_0: input tensor; its last axis is used as the channel axis
            (assumes channels-last layout).
        r: reduction ratio of the bottleneck. The hidden width is
            ``channels // r``, clamped to at least 1. Without the clamp, an
            input with fewer than ``r`` channels would request a Conv2D with
            zero filters.

    Returns:
        ``x_0`` multiplied channel-wise by the gate.
    """
    channels = x_0.shape[-1]
    hidden = max(1, channels // r)
    x = GlobalAvgPool2D()(x_0)
    # (batch, channels) -> (batch, 1, 1, channels) so the 1x1 convs can act on it
    x = x[:, None, None, :]
    x = Conv2D(filters=hidden, kernel_size=1, strides=1)(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=channels, kernel_size=1, strides=1)(x)
    x = Activation('sigmoid')(x)
    # the (batch, 1, 1, channels) gate broadcasts over the spatial dims of x_0
    return Multiply()([x_0, x])
# ----------------- #
# RepVGG模块
# ----------------- #
def RepVGGBlock(filters, kernel_size, strides=1, padding='valid', dilation=1, groups=1, deploy=False, use_se=False):
    """Build one RepVGG block, in training-time or deploy-time form.

    Training form (``deploy=False``) sums up to three branches:
      * BN applied to the raw input (identity branch). It is present only when
        the input already has ``filters`` channels and ``strides == 1``.
      * A ``kernel_size`` conv + BN (dense branch).
      * A 1x1 conv + BN.
    Deploy form (``deploy=True``) is a single conv with bias.
    In both forms an optional SE block is applied, followed by ReLU.

    Args:
        filters: output channels.
        kernel_size: kernel size of the dense branch and of the deploy conv.
        strides: stride shared by every conv branch.
        padding: Keras padding mode of the dense branch and of the deploy conv.
            With ``kernel_size > 1`` the training form needs ``'same'``, so
            that the dense and 1x1 branches have matching spatial shapes. With
            the default ``'valid'``, the ``Add`` fails on mismatched shapes.
        dilation: dilation rate of the dense branch and of the deploy conv.
        groups: convolution groups of the conv branches.
        deploy: build the single-conv inference form.
        use_se: insert an SE block before the final ReLU.

    Returns:
        A function mapping an input tensor to the block output.
    """
    def _RepVGGBlock(inputs):
        if deploy:
            x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                       padding=padding, dilation_rate=dilation, groups=groups,
                       use_bias=True)(inputs)
        else:
            branches = []
            if inputs.shape[-1] == filters and strides == 1:
                # identity branch: BN only, no conv
                branches.append(BatchNormalization()(inputs))
            # Dense branch. It is built inline rather than with conv_bn because it
            # must honour `dilation` exactly like the deploy conv above.
            # Otherwise the two forms of the block would not match when dilation != 1.
            dense = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                           padding=padding, dilation_rate=dilation, groups=groups,
                           use_bias=False)(inputs)
            branches.append(BatchNormalization()(dense))
            # 1x1 branch. For a 1x1 kernel, 'valid' and 'same' give the same output size.
            branches.append(conv_bn(filters=filters, kernel_size=1, strides=strides,
                                    padding='valid', groups=groups)(inputs))
            x = Add()(branches)
        if use_se:
            x = SE_block(x)
        return Activation('relu')(x)
    return _RepVGGBlock
# ----------------- #
# RepVGG模块的堆叠
# ----------------- #
def make_stage(planes, num_blocks, stride_1, deploy, use_se, override_groups_map=None, start_layer_id=1):
    """Stack ``num_blocks`` 3x3 RepVGG blocks into one stage.

    Args:
        planes: output channels of every block in the stage.
        num_blocks: number of blocks; 0 yields an identity (no blocks).
        stride_1: stride of the first block; the remaining blocks use stride 1.
        deploy: forwarded to ``RepVGGBlock``.
        use_se: forwarded to ``RepVGGBlock``.
        override_groups_map: optional ``{layer_id: groups}``. Blocks whose id
            is not in the map use ``groups=1``. ``None`` means no overrides.
        start_layer_id: id of this stage's first block. The ids in
            ``override_groups_map`` (e.g. 2..26 in ``g2_map``) are counted
            across the whole network, not per stage. Callers stacking several
            stages should therefore pass the running block count. Defaults to 1.

    Returns:
        A function mapping an input tensor to the stage output.
    """
    groups_map = override_groups_map or {}

    def _make_stage(x):
        strides = [stride_1 if i == 0 else 1 for i in range(num_blocks)]
        for layer_id, stride in enumerate(strides, start=start_layer_id):
            x = RepVGGBlock(filters=planes, kernel_size=3, strides=stride, padding='same',
                            groups=groups_map.get(layer_id, 1), deploy=deploy,
                            use_se=use_se)(x)
        return x
    return _make_stage
# ----------------- #
# RepVGG network
# ----------------- #
def RepVGG(x, num_blocks, classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False):
    """Build a RepVGG classifier on top of input tensor ``x``.

    Structure:
      * a stride-2 stem block with ``min(64, int(64 * width_multiplier[0]))``
        channels;
      * four stages, each starting with a stride-2 block, with base widths
        64/128/256/512 scaled by ``width_multiplier``;
      * global average pooling;
      * ``Dense(classes)`` with no activation, so the output is logits.

    Args:
        x: input tensor (channels-last).
        num_blocks: 4 block counts, one per stage.
        classes: number of output units.
        width_multiplier: 4 width multipliers, one per stage.
        override_groups_map: optional ``{layer_id: groups}``. Layer ids are
            counted across all stages, and id 1 is the first block after the stem.
        deploy: forwarded to every block.
        use_se: forwarded to every block.

    Returns:
        Output tensor of shape ``(batch, classes)``.

    Raises:
        ValueError: if ``width_multiplier`` or ``num_blocks`` is not a
            sequence of 4 entries.
    """
    if width_multiplier is None or len(width_multiplier) != 4:
        raise ValueError('width_multiplier must be a sequence of 4 numbers')
    if len(num_blocks) != 4:
        raise ValueError('num_blocks must be a sequence of 4 integers')
    override_groups_map = override_groups_map or dict()
    in_planes = min(64, int(64 * width_multiplier[0]))
    out = RepVGGBlock(filters=in_planes, kernel_size=3, strides=2, padding='same',
                      deploy=deploy, use_se=use_se)(x)
    layer_id = 1
    for base, mult, n in zip((64, 128, 256, 512), width_multiplier, num_blocks):
        out = make_stage(int(base * mult), n, stride_1=2, deploy=deploy, use_se=use_se,
                         override_groups_map=override_groups_map,
                         start_layer_id=layer_id)(out)
        layer_id += n
    out = GlobalAvgPool2D()(out)
    out = Dense(classes)(out)
    return out
# Layer ids (the even ids 2..26) whose blocks use grouped convolutions in the
# g2 / g4 model variants, mapped to the group count.
optional_groupwise_layers = list(range(2, 27, 2))
g2_map = dict.fromkeys(optional_groupwise_layers, 2)
g4_map = dict.fromkeys(optional_groupwise_layers, 4)
def RepVGG_A0(inputs, classes=1000, deploy=False):
    """Build RepVGG-A0: stage depths (2, 4, 14, 1), width multipliers (0.75, 0.75, 0.75, 2.5)."""
    config = dict(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[0.75, 0.75, 0.75, 2.5],
        override_groups_map=None,
    )
    return RepVGG(inputs, classes=classes, deploy=deploy, **config)
def create_RepVGG_A1(x, deploy=False, classes=1000):
    """Build RepVGG-A1: stage depths (2, 4, 14, 1), width multipliers (1, 1, 1, 2.5).

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000, previously hard-coded).
    """
    return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=classes,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
def create_RepVGG_A2(x, deploy=False, classes=1000):
    """Build RepVGG-A2: stage depths (2, 4, 14, 1), width multipliers (1.5, 1.5, 1.5, 2.75).

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=classes,
                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy)
def create_RepVGG_B0(x, deploy=False, classes=1000):
    """Build RepVGG-B0: stage depths (4, 6, 16, 1), width multipliers (1, 1, 1, 2.5).

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
def create_RepVGG_B1(x, deploy=False, classes=1000):
    """Build RepVGG-B1: stage depths (4, 6, 16, 1), width multipliers (2, 2, 2, 4).

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy)
def create_RepVGG_B1g2(x, deploy=False, classes=1000):
    """Build RepVGG-B1g2: RepVGG-B1 with the blocks listed in ``g2_map`` using 2 groups.

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy)
def create_RepVGG_B1g4(x, deploy=False, classes=1000):
    """Build RepVGG-B1g4: RepVGG-B1 with the blocks listed in ``g4_map`` using 4 groups.

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy)
def create_RepVGG_B2(x, deploy=False, classes=1000):
    """Build RepVGG-B2: stage depths (4, 6, 16, 1), width multipliers (2.5, 2.5, 2.5, 5).

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy)
def create_RepVGG_B2g2(x, deploy=False, classes=1000):
    """Build RepVGG-B2g2: RepVGG-B2 with the blocks listed in ``g2_map`` using 2 groups.

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy)
def create_RepVGG_B2g4(x, deploy=False, classes=1000):
    """Build RepVGG-B2g4: RepVGG-B2 with the blocks listed in ``g4_map`` using 4 groups.

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy)
def create_RepVGG_B3(x, deploy=False, classes=1000):
    """Build RepVGG-B3: stage depths (4, 6, 16, 1), width multipliers (3, 3, 3, 5).

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy)
def create_RepVGG_B3g2(x, deploy=False, classes=1000):
    """Build RepVGG-B3g2: RepVGG-B3 with the blocks listed in ``g2_map`` using 2 groups.

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy)
def create_RepVGG_B3g4(x, deploy=False, classes=1000):
    """Build RepVGG-B3g4: RepVGG-B3 with the blocks listed in ``g4_map`` using 4 groups.

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy)
def create_RepVGG_D2se(x, deploy=False, classes=1000):
    """Build RepVGG-D2se: depths (8, 14, 24, 1), width multipliers (2.5, 2.5, 2.5, 5), SE enabled.

    ``RepVGG`` takes ``classes``. The previous ``num_classes=`` keyword raised a
    TypeError on every call.

    Args:
        x: input tensor.
        deploy: build the fused single-conv inference form.
        classes: number of output units (default 1000).
    """
    return RepVGG(x, num_blocks=[8, 14, 24, 1], classes=classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None,
                  deploy=deploy, use_se=True)
if __name__ == '__main__':
    # Build the training-time RepVGG-A0 graph for 224x224x3 input and print its summary.
    inputs = Input(shape=(224, 224, 3))
    classes = 1000
    # `classes` was previously assigned but never used; pass it through explicitly.
    model = Model(inputs=inputs, outputs=RepVGG_A0(inputs, classes=classes))
    model.summary()
欢迎指正~