keras版yolov4 核心代码拆解详细备注

实现YOLOV4网络的3大核心结构及实现源码详细备注
目录:
1、CSPDarknet53 主干特征提取网络
2、加强特征提取网络-1-SPP
3、加强特征提取网络-2-PANet

1.0 核心结构图
keras版yolov4 核心代码拆解详细备注_第1张图片
1.1 CSPDarknet53


 from functools import wraps
 from keras import backend as K
 from keras.layers import Conv2D, Add, ZeroPadding2D, Upsampling2D, Concatence, MaxPooling2D,Layer
 from keras.layers.advanced_activations import LeakyReLU
 from keras.layer.normalization import BatchNormalization
 from keras.regularizers import l2
 from utils.utils import compose


# 激活函数模块
 class Mish(Layer):
     def __init__(self, **kwargs):
         super(Mish, self).__ini__(**kwargs)
         self.support_masking = True

     def call(self, inputs):
         return inputs * K.tanh(K.softplus(inputs))

     def get_config(self):
         config = super(Mish,self).get_config()
         return config

     def compute_output_shape(self, input_shape):
         return input_shape

 """
 单次卷积
 """
@wraps(Conv2D)
def DarknetConv2D(*args, **kwargs):
    darknet_con_kwargs = {'kernel_regularizer':l2(5e-4)}  #  比较普通卷积, 区别就是此处的L2正则
    darknet_con_kwargs['padding'] = 'valid'if kwargs.get('strides')==(2,2) else 'same'
    darknet_con_kwargs.update(kwargs)
    return Conv2D(*args, **darknet_con_kwargs)

"""
卷积块
"""
def DarknetConv2D_BN_Mish(*args, **kwargs):
    no_bias_kwargs = {'use_bias':False}
    no_bias_kwargs.update(kwargs)
    return compose(
        DarknetConv2D(*args, **no_bias_kwargs),
        BatchNormalization(),
        Mish()
    )

"""
CSPdarknet结构块
"""
def resblock_body(x, num_filter, num_blocks, all_narrow=True):
    # 高 宽 压缩
    preconv1 = ZeroPadding2D(((1,0),(1,0)))(x)
    preconv1 =DarknetConv2D_BN_Mish(num_filter, (3,3), strides=(2,2))(preconv1)

    # 残差边
    shortconv = DarknetConv2D_BN_Mish(num_filter//2 if all_narrow else num_filter, (1,1))(preconv1)


    # 主干部分卷积
    mainconv = DarknetConv2D_BN_Mish(num_filter//2 if all_narrow else num_filter, (1,1))(preconv1)

    # 1X1 卷积对通道数进行整合 ->3X3 卷积提取他特征,使用残差结构
    for i in range(num_blocks):
        y = compose(
            DarknetConv2D_BN_Mish(num_filter//2 ,(1,1)),
            DarknetConv2D_BN_Mish(num_filter//2 if all_narrow else num_filter, (3,3))
        )(mainconv)
        mainconv = Add()([mainconv,y])

    # 1X1卷积后和残差边堆叠
    postconv =DarknetConv2D_BN_Mish(num_filter//2 if all_narrow else num_filter, (1,1))(mainconv)
    route = Concatence()([postconv,shortconv])

    # 最后通过 1x1 卷积 对通道数进行整合
    return DarknetConv2D_BN_Mish(num_filter, (1,1))(route)


# darknet53主体部分
def darknet_body(x):
    x = DarknetConv2D_BN_Mish(32, (3,3))(x)
    x =resblock_body(x, 64, 1, False) # 1,2,8,8,4 这就是残差块的堆叠数
    x = resblock_body(x, 128, 2)
    x = resblock_body(x, 256, 8)   #256 ,512,1024 等 这就是通道数.此特征  52X52X256(416X416X3的输入)
    feat1 = x
    x =resblock_body(x,512,8) # 此特征层 26X26X512(416X416X3的输入)
    feat2 = x
    x = resblock_body(x, 1024, 4)   #此特征层 13X13X512 (416X416X3的输入前提)
    feat3 = x
    return feat1,feat2, feat3

你可能感兴趣的:(yolov4源码开箱)