deeplabV3+网络实现及解析

简介

  DeepLab V3+通过添加一个简单而有效的解码器模块, 从而扩展了DeepLab V3,以优化分割结果,尤其是沿对象边界的分割结果。 我们进一步探索Xception模型,并将深度可分离卷积应用于 Atrous空间金字塔池和解码器模块,从而形成更快、更强大的编码器-解码器网络。

网络结构

  DeepLab V3+的主要模型结构如下图:

deeplabV3+网络实现及解析_第1张图片

  这里的结构分为了两部分:encoder和decoder。encoder中连接的第一个模块是DCNN, 他代表的是用于提取图片特征的主干网络,DCNN右边是一个ASPP网络,他用一个1*1的卷积、3个3*3的 空洞卷积和一个全局池化来对主干网络的输出进行处理。然后再将其结果都连接起来并用一个1*1的卷积 来缩减通道数。

  然后是decoder部分。他会将主干网络的中间输出和ASPP的输出变换成相同形状,然后将 其连接在一起,再进行3*3的卷积。最后利用这个结果来继续语义分割。

空洞卷积

  在deeplab中为了扩大感受野使用的空洞卷积,空洞卷积与普通卷积的主要区别如下图:

deeplabV3+网络实现及解析_第2张图片

 上图中(a)即普通卷积:取相邻的点来进行卷积。图(b)和(c)为空洞卷积,其取的是相隔 指定数量的点来进行卷积。

 相较于普通卷积,空洞卷积能够在不增加参数的情况下增加感受野,如上图(b),同样是3*3的 卷积核,(b)的感受野却与7*7的感受野相同。

深度可分离卷积

  深度可分离卷积可以使用更少的参数实现与普通卷积相同的效果,两者的区别如下图:

deeplabV3+网络实现及解析_第3张图片

  深度可分离卷积分为两部分:depthwise卷积和pointwise卷积。 depthwise卷积(图中(b)):对每个输入通道单独使用一个卷积核处理。 pointwise卷积(图中(c)):1*1卷积,用于将depthwise卷积的输出组合起来。

主干网络

  DeepLab V3+使用的主干网络是Xception网络,并对该网络进行了一定的改动, 其网络结构如下:

deeplabV3+网络实现及解析_第4张图片

  这里主要分为三个部分:Entry flow、Middle flow 和 Exit flow。

  在Entry flow中,首先是两个3*3的卷积,然后是三个用深度可分离卷积代替 普通3*3卷积的残差模块。然后是Middle flow, 这是一个左侧没有卷积的残差模块,同样 这里也是用的深度可分离卷积,然后重复16次。最后是Exit flow,这里就是一个残差模块和 三个深度可分离卷积。

keras实现

整体代码

  keras具体实现如下:

def deeplab(input_shape=(512,512,3),num_class=1):
    input= KL.Input(shape=input_shape)

    #   xception特征层
    x,atrous,skip1 = Xception(input)

    size_before = K.int_shape(x)

    b0 = KL.Conv2D(256,(1,1),padding="same",use_bias=False,name="aspp0")(x)
    b0 = KL.BatchNormalization(name="aspp0_BN",epsilon=1e-5)(b0)
    b0 = KL.Activation("relu",name="aspp0_activation")(b0)

    b1 = SepConv_BN(x,256,"aspp1",rate=atrous[0],depth_activation=True,epsilon=1e-5)
    b2 = SepConv_BN(x,256,"aspp2",rate=atrous[1],depth_activation=True,epsilon=1e-5)
    b3 = SepConv_BN(x,256,"aspp3",rate=atrous[2],depth_activation=True,epsilon=1e-5)

    b4 = KL.GlobalAveragePooling2D()(x)
    b4 = KL.Lambda(lambda x:K.expand_dims(x,1))(b4)
    b4 = KL.Lambda(lambda x:K.expand_dims(x,1))(b4)
    b4 = KL.Conv2D(256,(1,1),padding="same",use_bias=False,name="image_pooling")(b4)
    b4 = KL.BatchNormalization(name="image_pooling_BN",epsilon=1e-5)(b4)
    b4 = KL.Activation("relu")(b4)

    b4 = KL.Lambda(lambda x : tf.image.resize(x,size_before[1:3]))(b4)

    x = KL.Concatenate()([b4,b0,b1,b2,b3])

    x = KL.Conv2D(256,(1,1),padding="same",use_bias=False,name="concat_projection")(x)
    x = KL.BatchNormalization(name="concat_projection_BN",epsilon=1e-5)(x)
    x = KL.Activation("relu")(x)
    x = KL.Dropout(0.1)(x)

    skip_size = K.int_shape(skip1)
    x = KL.Lambda( lambda xx:tf.image.resize(xx,skip_size[1:3]))(x)

    dec_skip1 = KL.Conv2D(48,(1,1),padding="same",use_bias=False,name="feature_projection0")(skip1)
    dec_skip1 = KL.BatchNormalization(name="feature_projection0_BN",epsilon=1e-5)(dec_skip1)
    dec_skip1 = KL.Activation("relu")(dec_skip1)

    x = KL.Concatenate()([x,dec_skip1])
    x = SepConv_BN(x,256,"decoder_conv0",depth_activation=True,epsilon=1e-5)
    x = SepConv_BN(x,256,"decoder_conv1",depth_activation=True,epsilon=1e-5)

    size_before3 = K.int_shape(input)
    x = KL.Conv2D(num_class,(1,1),padding="same")(x)
    x = KL.Lambda(lambda xx:tf.image.resize(xx,size_before3[1:3]))(x)
    x = KL.Activation("sigmoid")(x)

    model= Model(inputs=input,outputs=x)
    return model

  首先是第5行调用的Xception方法,这个方法是Xception主干网络的实现, 他返回了三个参数:x,atrous,skip1。其中x是Xception的最终输出,atrous是ASPP网络 使用的空洞卷积的参数,skip1是Xception的一个浅层输出特征。

  然后是第9行到第31行,这里是ASPP网络的实现。首先是第9行到第11行,这里就 是一个1*1的卷积层和对应的标准化与激活函数。然后是第13行到第15行是三个3*3的空洞卷积, 这里调用的SepConv_BN方法是一个实现支持空洞卷积的深度可分离卷积。然后是第17到第24行, 这里是一个全局池化层,池化过后先添加两个维度,使其维度与其他层相同,然后再使用一个1*1 的卷积层缩减其通道数,最后在将其形状resize到与其他特征相同。最后是第26行到第31行,这 里将上述的5个特征连接在一起,然后使用1*1的卷积缩减通道数。

  然后是第33行到第47行,这里对应的是图中decoder模块。首先是第34行,这里是在 处理encoder的最终输出,他会将其resize成encoder的浅层特征(skip1)的形状,然后是第36 行到第38行,这里使用了一个1*1的卷积层对其进行通道数缩减。然后是第40行到第42行,这里将两 个输出连接到一起,然后再添加两个深度可分离卷积。最后是第44行到第47行,他首先是用一个1*1 的卷积层来代替全连接层来预测每个像素点的分类概率,然后将输出的形状resize到原图的大小,最 后再添加一个sigmoid激活函数(激活函数可以根据需求变化)。

主干网络Xception

  上文提到的主干网络Xception的实现如下:

def Xception(image,alpha=1.,downsample_factor=16):


    if downsample_factor == 8:
        entry_block3_stride = 1
        middle_block_rate = 2
        exit_block_rates = (2,4)
        atrous_rates = (12,24,36)

    elif downsample_factor == 16:
        entry_block3_stride = 2
        middle_block_rate = 1
        exit_block_rates = (1,2)
        atrous_rates = (6,12,18)

    x = KL.Conv2D(32,(3,3),strides=(2,2),name="entry_flow_conv1_1",use_bias=False,padding="same")(image)
    x = KL.BatchNormalization(name="entry_flow_conv1_1_BN")(x)
    x = KL.Activation("relu")(x)

    x = _conv2d_same(x,64,"entry_flow_conv1_2",kernel_size=3,stride=1)
    x = KL.BatchNormalization(name="entry_flow_conv1_2_BN")(x)
    x = KL.Activation("relu")(x)

    x = _xception_block(x,[128,128,128],"entry_flow_block1",skip_connection_type="conv",stride=2,depth_activation=False)
    x,skip1 = _xception_block(x,[256,256,256],"entry_flow_block2",skip_connection_type="conv",stride=2,depth_activation=False,return_skip=True)
    x = _xception_block(x,[728,728,728],"entry_flow_block3",skip_connection_type="conv",stride=entry_block3_stride,depth_activation=False)

    for i in range(16):
        x = _xception_block(x,[728,728,728],"middle_flow_unit_{}".format(i+1),skip_connection_type="sum",stride=1,rate=middle_block_rate,depth_activation=False)

    x = _xception_block(x,[728,1024,1024],"exit_flow_block1",skip_connection_type="conv",stride=1,rate=exit_block_rates[0],depth_activation=False)
    x = _xception_block(x,[1536,1536,2048],"exit_flow_block2",skip_connection_type="none",stride=1,rate=exit_block_rates[1],depth_activation=True)

    return x,atrous_rates,skip1

  首先是第4行到第14行,这里会根据downsample_factor来初始化一系列参数。然后 是第16行到第22行,这里是两个3*3的卷积层。然后是第24行到第26行,这里是三个_xception_block ,这个模块即上文提到的用深度可分离卷积代替普通卷积的残差模块。从第16行到第26行,这里对应的是 Xception中的entry flow。然后是第28行和第29行,这里对应的middle flow,即重复16次残差模 块,最后第31行和第32行,这里对应的exit flow,这里使用的是两个残差模块。

  这里提到的_xception_block方法内容如下:

def _xception_block(inputs,depth_list,prefix,skip_connection_type,stride,
                    rate=1,depth_activation=False,return_skip=False):

    residual = inputs

    for i in range(3):
        residual = SepConv_BN(residual,depth_list[i],prefix+"_separable_conv{}".format(i+1),stride= stride if i == 2 else 1,
                              rate=rate,depth_activation=depth_activation)
        if i == 1 :
            skip = residual

    if skip_connection_type == "conv":
        shortcut = _conv2d_same(inputs,depth_list[-1],prefix+"_shortcut",kernel_size=1,stride=stride)
        shortcut = KL.BatchNormalization(name=prefix+"_shortcut_BN")(shortcut)
        outputs = KL.Add()([residual,shortcut])

    elif skip_connection_type == "sum":
        outputs = KL.Add()([residual,inputs])
    elif skip_connection_type == "none":
        outputs = residual

    if return_skip:
        return outputs,skip
    else:
        return outputs

  首先是第6行到第10行,这里循环了3次SepConv_BN方法,这个方法实现了深度可 分离卷积。然后是第12行到第19行,这里会根据skip_connection_type的类型来执行不同的 操作,首先是conv的情况,代表是残差边有卷积的情况,这里会先执行一个卷积再将两者相加,然 后是sum的情况,代表残差边没有卷积的情况,最后是none,代表不用残差。

  这里使用的SepConv_BN方法如下:

def SepConv_BN(x,filters,prefix,stride=1,kernel_size=3,rate=1,depth_activation=False,epsilon=1e-3):

    if stride == 1:
        depth_padding = "same"

    else:
        kernel_size_effective = kernel_size + (kernel_size -1 ) * (rate -1)
        pad_total = kernel_size_effective - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        x = KL.ZeroPadding2D((pad_beg,pad_end))(x)
        depth_padding = "valid"

    if not  depth_activation:
        x = KL.Activation("relu")(x)


    x = KL.DepthwiseConv2D((kernel_size,kernel_size),strides=(stride,stride),dilation_rate=(rate,rate),
                           padding=depth_padding,use_bias=False,name=prefix+"_depthwise")(x)
    x = KL.BatchNormalization(name=prefix+"_depthwis_BN",epsilon=epsilon)(x)
    if depth_activation:
        x = KL.Activation("relu")(x)

    x = KL.Conv2D(filters,(1,1),padding="same",use_bias=False,name=prefix+"_pointwise")(x)
    x = KL.BatchNormalization(name=prefix+ "_pointwise_BN",epsilon=epsilon)(x)
    if depth_activation:
        x= KL.Activation("relu")(x)

    return x

  首先是第3行到第12行,这里会根据步长的不同来选择不同padding方式。然后是第14行和 第15行,这里会根据depth_activation来决定是否执行激活函数。然后是第19行到第20行,这里执行 了一个depthwise卷积,然后是第24行和第25行,这里执行了一个1*1的卷积,即pointwise卷积。

模型训练与预测

模型训练

  这里的模型是使用keras实现的,keras模型的训练很简单,只需要调用fit方法 或fit_generator方法便可。

  在训练前只需要处理好模型的输入与输出便可。语义分割需要的如下图所示:

deeplabV3+网络实现及解析_第5张图片

  上图是一个语义分割的数据集,左边的图片是数据集的输入,右边的是输出。 这是一个二分类的语义分割任务,右边白色的标签为1,黑色的标签为0。

  主要训练代码如下:

def train():
    #图片与标签路径
    train_csv = "XXX"
    train_image = "XXX"
    
    #模型保存路径
    checkpoint_path = "../savemodel/deeplab_tld2.{epoch:02d}-{val_loss:.2f}.h5"
    #除了输入与输出
    trains, vals, test = utils.get_HSAgenerator_all(train_csv, train_image,train_batch_size=8)

    cb = [
        ModelCheckpoint(checkpoint_path, verbose=0),
        TensorBoard(log_dir="../logs/deeplab_tld2")
    ]


    model = deeplab(input_shape=(256, 256, 3))
    lr = 1e-4
    a = Adam(learning_rate=lr)
    model.compile(loss=utils.lossdif2, optimizer=a, metrics="acc")

    model.fit_generator(trains, steps_per_epoch=2625, epochs=50, validation_data=vals,
                        validation_steps=375, callbacks=cb)

  训练的代码主要如上所示,这里主要是使用的fit_generator方法来训练。 他要求传入的输入是一个generator类,这个类可以使用yield关键字来快速实现。

  然后是callbacks,这里定义了两个类:ModelCheckpoint和TensorBoard。 ModelCheckpoint的主要作用是在模型训练完成指定的epoch后保存模型,默认是每个批次都保存模型。 TensorBoard主要的作用是存储TensorBoard需要的数据。

  最后是模型编译需要的一些参数。损失函数一般可以交叉熵函数,optimizer可以使用adam或 sgd(随机梯度下降)等,metrics这里使用的acc(准确率)。

模型预测

  这里的模型是使用keras实现的,keras模型的预测也很简单。主要代码如下:

def predict_hsa():
   
    name = "XXX"
    model_path = "XXX"

    r = []

    m = deeplab(input_shape=(256, 256, 3))
    m.load_weights(model_path)

 
    threshold = 0.5
    image = cv2.imread(name)
    image = cv2.resize(image,(256,256))

    p = m.predict(np.array([image]))
    mask = p[0]
    mask = np.where(mask >= threshold, 1, 0).astype(np.uint8)
    mask = cv2.resize(mask, (512, 512))
    utils.show(image,mask)

  首先是第8行和第9行,这里先定义了deeplab模型,然后加载训练好的模型。然后是第12行到 第14行,这里主要是读取图片并resize成模型需求的形状。然后是第16行调用predict方法进行预测,最 后是第17行到第20行,这里在处理模型预测结果并显示。

  模型的预测效果主要如下:

deeplabV3+网络实现及解析_第6张图片

你可能感兴趣的:(深度学习,神经网络,tensorflow)