MultiResUNet Overview

Original paper: MultiResUNet: Rethinking the U-Net Architecture for Multimodal Biomedical Image Segmentation

Innovations

  1. Replaces the pair of 3×3 convolutions in each U-Net block with a multi-resolution block that merges 3×3, 5×5 and 7×7 convolution operations in parallel, rethinking the conventional convolutional layers along multi-resolution lines.
  2. Replaces the plain skip connections of U-Net with Res Paths.
  3. Delivers a clear improvement on challenging datasets.

MultiRes block

The simplest way to extend U-Net with multi-resolution analysis is to run 3×3 and 7×7 convolution operations in parallel with the 5×5 convolution and merge their outputs, as shown in Fig. 3a of the paper.
Instead, the paper factorises the larger, more expensive 5×5 and 7×7 convolutional layers into sequences of smaller, lighter 3×3 convolutions, as shown in Fig. 3b: the output of two chained 3×3 convolutions effectively approximates a 5×5 convolution, and the output of three chained 3×3 convolutions effectively approximates a 7×7 convolution.
The final MultiResUNet block therefore replaces the U-Net block with three chained 3×3 convolutions, introduces an additional 1×1 convolutional layer as a residual connection, and thereby lets the model pick up some extra spatial information.
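The filter counts used in the model code further down appear to follow the paper's allocation rule: each MultiRes block gets W = α·U filters (α = 1.67, with U the filter count of the corresponding U-Net level), split roughly 1/6, 1/3 and 1/2 across the three 3×3 convolutions, with the 1×1 residual convolution sized to their sum. A minimal sketch of that calculation (the helper name is mine, and one encoder level in the code below rounds slightly differently):

def multires_filters(u, alpha=1.67):
    # u: filter count of the corresponding U-Net level (32, 64, 128, 256, 512)
    w = alpha * u
    f1, f2, f3 = int(w * 0.167), int(w * 0.333), int(w * 0.5)
    return f1, f2, f3, f1 + f2 + f3  # the last value sizes the 1x1 residual conv

print(multires_filters(32))   # (8, 17, 26, 51)      -- first and last block below
print(multires_filters(512))  # (142, 284, 427, 853) -- bottleneck block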

Res Path

With the Res Path, the feature maps are not simply concatenated from the encoder stage to the decoder stage: they are first passed through a chain of convolutional layers with residual connections, and only then concatenated with the decoder features.
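Each unit along a Res Path is a 3×3 convolution with a parallel 1×1 convolution acting as the residual shortcut, merged by addition; shallower skip connections chain more of these units (four at the first level, down to one at the fourth). A minimal sketch of a single unit, using the same Keras layers as the model code below:

def res_path_unit(x, filters):
    # 3x3 convolution in parallel with a 1x1 residual (shortcut) convolution
    conv3x3 = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    conv1x1 = Conv2D(filters, (1, 1), padding='same', activation='relu')(x)
    return Add()([conv3x3, conv1x1])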

Experiments

The results reported in the paper show MultiResUNet improving on U-Net across the board. In my own reproduction, however, I found that when the dataset is not challenging, i.e. when the objects are already cleanly segmentable, U-Net actually performs better.

Model code

# Keras / TensorFlow imports used by the model code below
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, UpSampling2D, Concatenate, Add
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def mlti_res_block(inputs, filter_size1, filter_size2, filter_size3, filter_size4):
    # Chained 3x3 convolutions: their outputs approximate 3x3, 5x5 and 7x7 receptive fields
    cnn1 = Conv2D(filter_size1, (3, 3), padding='same', activation="relu")(inputs)
    cnn2 = Conv2D(filter_size2, (3, 3), padding='same', activation="relu")(cnn1)
    cnn3 = Conv2D(filter_size3, (3, 3), padding='same', activation="relu")(cnn2)

    # 1x1 convolution on the block input acts as the residual (shortcut) branch;
    # filter_size4 must equal filter_size1 + filter_size2 + filter_size3 for the Add to work
    cnn = Conv2D(filter_size4, (1, 1), padding='same', activation="relu")(inputs)

    # Concatenate the multi-scale features, then add the residual branch
    concat = Concatenate()([cnn1, cnn2, cnn3])
    add = Add()([concat, cnn])

    return add

def res_path(inputs, filter_size, path_number):
    # One Res Path unit: a 3x3 convolution plus a parallel 1x1 residual convolution
    def block(x, fl):
        cnn1 = Conv2D(fl, (3, 3), padding='same', activation="relu")(x)
        cnn2 = Conv2D(fl, (1, 1), padding='same', activation="relu")(x)

        return Add()([cnn1, cnn2])

    # Shallower skip connections pass through more units:
    # path_number 1 -> 4 units, 2 -> 3 units, 3 -> 2 units, 4 -> 1 unit
    cnn = block(inputs, filter_size)
    if path_number <= 3:
        cnn = block(cnn, filter_size)
        if path_number <= 2:
            cnn = block(cnn, filter_size)
            if path_number <= 1:
                cnn = block(cnn, filter_size)

    return cnn


def multi_res_u_net(pretrained_weights=None, input_size=(256, 256, 1), lr=0.001):
    inputs = Input(input_size)

    # Encoder: four MultiRes blocks, each followed by 2x2 max pooling
    res_block1 = mlti_res_block(inputs, 8, 17, 26, 51)
    pool1 = MaxPool2D()(res_block1)

    res_block2 = mlti_res_block(pool1, 17, 35, 53, 105)
    pool2 = MaxPool2D()(res_block2)

    res_block3 = mlti_res_block(pool2, 31, 72, 106, 209)
    pool3 = MaxPool2D()(res_block3)

    res_block4 = mlti_res_block(pool3, 71, 142, 213, 426)
    pool4 = MaxPool2D()(res_block4)

    # Bottleneck
    res_block5 = mlti_res_block(pool4, 142, 284, 427, 853)
    upsample = UpSampling2D()(res_block5)

    # Decoder: upsample, concatenate with the Res Path features of the
    # corresponding encoder level, then apply another MultiRes block
    res_path4 = res_path(res_block4, 256, 4)
    concat = Concatenate()([upsample, res_path4])

    res_block6 = mlti_res_block(concat, 71, 142, 213, 426)
    upsample = UpSampling2D()(res_block6)

    res_path3 = res_path(res_block3, 128, 3)
    concat = Concatenate()([upsample, res_path3])

    res_block7 = mlti_res_block(concat, 31, 72, 106, 209)
    upsample = UpSampling2D()(res_block7)

    res_path2 = res_path(res_block2, 64, 2)
    concat = Concatenate()([upsample, res_path2])

    res_block8 = mlti_res_block(concat, 17, 35, 53, 105)
    upsample = UpSampling2D()(res_block8)

    res_path1 = res_path(res_block1, 32, 1)
    concat = Concatenate()([upsample, res_path1])

    res_block9 = mlti_res_block(concat, 8, 17, 26, 51)

    # 1x1 convolution with sigmoid activation produces the binary segmentation map
    sigmoid = Conv2D(1, (1, 1), padding='same', activation="sigmoid")(res_block9)

    model = Model(inputs, sigmoid)
    model.compile(optimizer=Adam(learning_rate=lr), loss='binary_crossentropy', metrics=['accuracy'])

    if (pretrained_weights):
        model.load_weights(pretrained_weights)

    return model
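
A minimal usage sketch, assuming the definitions above; the training arrays here are random placeholders, only to show the expected shapes:

import numpy as np

# Build the model for 256x256 single-channel images; lr=1e-4 matches the
# learning rate used in the original compile call
model = multi_res_u_net(input_size=(256, 256, 1), lr=1e-4)
model.summary()

# x_train: (N, 256, 256, 1) images scaled to [0, 1]
# y_train: (N, 256, 256, 1) binary masks
x_train = np.random.rand(4, 256, 256, 1).astype("float32")
y_train = (np.random.rand(4, 256, 256, 1) > 0.5).astype("float32")

model.fit(x_train, y_train, batch_size=2, epochs=1)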
