Original paper: MultiResUNet: Rethinking the U-Net Architecture for Multimodal Biomedical Image Segmentation
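The simplest way to extend U-Net with multi-resolution analysis is to add 5×5 and 7×7 convolutions in parallel with its 3×3 convolution and concatenate their outputs, as shown in Figure 3a.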
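Instead, the paper factorizes the larger and more expensive 5×5 and 7×7 convolution layers into a sequence of smaller, lighter 3×3 convolution blocks, as shown in Figure 3b. The output of two stacked 3×3 blocks effectively approximates a 5×5 convolution, and the output of three stacked 3×3 blocks approximates a 7×7 convolution.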
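The approximation can be checked with a little arithmetic. The sketch below assumes stride-1 convolutions and, for the weight comparison only, the same channel count C on every layer (the MultiRes block itself uses different counts per layer); bias terms are ignored.

```python
def stacked_receptive_field(n_layers, kernel=3):
    # Each stride-1 k x k convolution widens the receptive field by (k - 1) pixels.
    return 1 + n_layers * (kernel - 1)


def conv_weights(kernel, c_in, c_out):
    # Number of weights in one k x k convolution layer (bias ignored).
    return kernel * kernel * c_in * c_out


C = 32  # assumed channel count, for illustration only
two_3x3 = 2 * conv_weights(3, C, C)
one_5x5 = conv_weights(5, C, C)
three_3x3 = 3 * conv_weights(3, C, C)
one_7x7 = conv_weights(7, C, C)
```

Two stacked 3×3 layers see a 5×5 window and three see a 7×7 window. With C = 32 they need 18,432 and 27,648 weights, versus 25,600 for a single 5×5 layer and 50,176 for a single 7×7 layer.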
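The resulting MultiResUNet replaces each U-Net convolution block with three chained 3×3 convolutions whose outputs are concatenated. It also adds a residual connection through a 1×1 convolution, which lets the model capture some additional spatial information.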
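The filter counts passed to mlti_res_block in the code below (8, 17, 26, 51 and so on) look arbitrary, but most of them can be reproduced by a simple rule. The sketch assumes the allocation used in the paper's public Keras implementation, which this post does not state. That rule sets W = α·U with α = 1.67, where U is the plain U-Net filter count at that level, and splits W roughly into 1/6, 1/3 and 1/2 across the three 3×3 convolutions. The function name multires_filters is hypothetical.

```python
def multires_filters(U, alpha=1.67):
    """Filter counts for one MultiRes block: three 3x3 convs plus the 1x1 shortcut.

    Assumption: W = alpha * U, split with the factors 0.167 / 0.333 / 0.5.
    """
    W = alpha * U
    f1 = int(W * 0.167)
    f2 = int(W * 0.333)
    f3 = int(W * 0.5)
    # The 1x1 shortcut must output f1 + f2 + f3 channels so that Add() can sum it
    # with the concatenation of the three 3x3 outputs.
    return f1, f2, f3, f1 + f2 + f3


for U in (32, 64, 128, 256, 512):
    print(U, multires_filters(U))
```

For U = 32, 64, 256 and 512 this reproduces the numbers in the code exactly. For U = 128 it gives (35, 71, 106, 212), whereas the code uses (31, 72, 106, 209). That variant still builds, because 31 + 72 + 106 = 209 keeps the Add() shapes consistent.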
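The model also introduces Res Paths. Instead of concatenating the encoder feature maps directly with the decoder features, it first passes them through a chain of convolution layers with residual connections and only then concatenates the result with the decoder features.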
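To make the "chain of convolutions with residual connections" concrete without TensorFlow, here is a toy 1-D, single-channel analogue in pure Python. Each unit adds a ReLU'd 3-tap convolution to a ReLU'd 1×1 shortcut, mirroring the block inside res_path. The kernel (0.25, 0.5, 0.25) and shortcut weight 1.0 are hypothetical fixed values; a real Res Path learns its weights and works on many channels. blocks_for_level reproduces the if-chain in res_path, where shallower skip connections get more units.

```python
def conv1d_same(x, kernel):
    """Zero-padded 'same' 1-D cross-correlation (what Keras calls convolution); odd kernel length."""
    r = len(kernel) // 2
    padded = [0.0] * r + list(x) + [0.0] * r
    return [sum(k * padded[i + j] for j, k in enumerate(kernel)) for i in range(len(x))]


def relu(v):
    return [max(0.0, a) for a in v]


def res_path_unit(x, kernel, shortcut_w):
    """One unit: ReLU(3-tap conv) + ReLU(1x1 'conv', here a per-element scale), summed."""
    main = relu(conv1d_same(x, kernel))
    shortcut = relu([shortcut_w * a for a in x])
    return [m + s for m, s in zip(main, shortcut)]


def blocks_for_level(path_number):
    """Units per skip level, as in res_path: path 1 -> 4, 2 -> 3, 3 -> 2, 4 -> 1."""
    return 5 - path_number


def toy_res_path(x, path_number, kernel=(0.25, 0.5, 0.25), shortcut_w=1.0):
    # Hypothetical fixed weights; each unit's output feeds the next unit.
    for _ in range(blocks_for_level(path_number)):
        x = res_path_unit(x, kernel, shortcut_w)
    return x


print(toy_res_path([0, 0, 4, 0, 0], 4))
```

With path_number=4 (a single unit), [0, 0, 4, 0, 0] becomes [0.0, 1.0, 6.0, 1.0, 0.0]. The convolution branch spreads the spike to its neighbours, while the shortcut branch carries the original value through unchanged.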
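The paper reports that MultiResUNet outperforms U-Net across the board. In my own reproduction, however, U-Net performed better when the dataset was not very challenging, i.e. when the targets were already segmented fairly well. My Keras implementation is below.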
```python
# Assumes TensorFlow 2.x with its bundled Keras (tf.keras).
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, UpSampling2D, Concatenate, Add
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def mlti_res_block(inputs, filter_size1, filter_size2, filter_size3, filter_size4):
    # Three chained 3x3 convs: 2 and 3 stacked 3x3 convs approximate 5x5 and 7x7 receptive fields
    cnn1 = Conv2D(filter_size1, (3, 3), padding='same', activation="relu")(inputs)
    cnn2 = Conv2D(filter_size2, (3, 3), padding='same', activation="relu")(cnn1)
    cnn3 = Conv2D(filter_size3, (3, 3), padding='same', activation="relu")(cnn2)
    # 1x1 residual shortcut; filter_size4 must equal filter_size1 + filter_size2 + filter_size3
    cnn = Conv2D(filter_size4, (1, 1), padding='same', activation="relu")(inputs)
    concat = Concatenate()([cnn1, cnn2, cnn3])
    add = Add()([concat, cnn])
    return add


def res_path(inputs, filter_size, path_number):
    # One Res Path unit: a 3x3 conv and a 1x1 conv shortcut on the same input, summed
    def block(x, fl):
        cnn1 = Conv2D(fl, (3, 3), padding='same', activation="relu")(x)
        cnn2 = Conv2D(fl, (1, 1), padding='same', activation="relu")(x)
        add = Add()([cnn1, cnn2])
        return add

    # Units are chained; deeper levels get fewer: path 4 -> 1, 3 -> 2, 2 -> 3, 1 -> 4
    cnn = block(inputs, filter_size)
    if path_number <= 3:
        cnn = block(cnn, filter_size)
    if path_number <= 2:
        cnn = block(cnn, filter_size)
    if path_number <= 1:
        cnn = block(cnn, filter_size)
    return cnn


def multi_res_u_net(pretrained_weights=None, input_size=(256, 256, 1), lr=1e-4):
    # Height and width of input_size must be divisible by 16 (four 2x poolings)
    inputs = Input(input_size)

    # Encoder
    res_block1 = mlti_res_block(inputs, 8, 17, 26, 51)
    pool1 = MaxPool2D()(res_block1)
    res_block2 = mlti_res_block(pool1, 17, 35, 53, 105)
    pool2 = MaxPool2D()(res_block2)
    res_block3 = mlti_res_block(pool2, 31, 72, 106, 209)
    pool3 = MaxPool2D()(res_block3)
    res_block4 = mlti_res_block(pool3, 71, 142, 213, 426)
    pool4 = MaxPool2D()(res_block4)

    # Bottleneck
    res_block5 = mlti_res_block(pool4, 142, 284, 427, 853)

    # Decoder: upsample, then concatenate with the Res Path output of the matching encoder level
    upsample = UpSampling2D()(res_block5)
    res_path4 = res_path(res_block4, 256, 4)
    concat = Concatenate()([upsample, res_path4])
    res_block6 = mlti_res_block(concat, 71, 142, 213, 426)

    upsample = UpSampling2D()(res_block6)
    res_path3 = res_path(res_block3, 128, 3)
    concat = Concatenate()([upsample, res_path3])
    res_block7 = mlti_res_block(concat, 31, 72, 106, 209)

    upsample = UpSampling2D()(res_block7)
    res_path2 = res_path(res_block2, 64, 2)
    concat = Concatenate()([upsample, res_path2])
    res_block8 = mlti_res_block(concat, 17, 35, 53, 105)

    upsample = UpSampling2D()(res_block8)
    res_path1 = res_path(res_block1, 32, 1)
    concat = Concatenate()([upsample, res_path1])
    res_block9 = mlti_res_block(concat, 8, 17, 26, 51)

    # Per-pixel foreground probability for binary segmentation
    sigmoid = Conv2D(1, (1, 1), padding='same', activation="sigmoid")(res_block9)

    model = Model(inputs, sigmoid)
    # lr is the Adam learning rate (1e-4 by default)
    model.compile(optimizer=Adam(learning_rate=lr), loss='binary_crossentropy', metrics=['accuracy'])
    if pretrained_weights:
        model.load_weights(pretrained_weights)
    return model
```
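One practical constraint of multi_res_u_net: each UpSampling2D output is concatenated with the encoder tensor at the same level, so the spatial sizes must match exactly. The pure-Python sketch below assumes MaxPool2D's defaults (pool size 2, 'valid' padding, which floors odd sizes) and UpSampling2D's default factor of 2. It shows which input sizes survive the four poolings. encoder_sizes and decoder_matches are hypothetical helper names.

```python
def encoder_sizes(size, depth=4):
    """Spatial size at each encoder level when MaxPool2D (pool 2, 'valid') is applied depth times."""
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)  # 'valid' pooling floors odd sizes
    return sizes


def decoder_matches(size, depth=4):
    """True if every 2x upsampled tensor has the same size as the skip tensor it is concatenated with."""
    sizes = encoder_sizes(size, depth)
    return all(2 * sizes[i + 1] == sizes[i] for i in range(depth))


print(encoder_sizes(256), decoder_matches(256))
print(encoder_sizes(250), decoder_matches(250))
```

With the default 256 the sizes are [256, 128, 64, 32, 16] and every concatenation lines up. With 250 they are [250, 125, 62, 31, 15], so upsampling the bottleneck gives 30, which cannot be concatenated with the 31-pixel res_path4 output. In practice, input height and width should be multiples of 16.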