本文介绍 UNet++(Zhou et al., 2018):2018年6月的文章,发表于DLMIA 2018会议,作者单位为Department of Biomedical Informatics, Arizona State University。
原作者在知乎上给出了对文章的思路总结
文章对Unet改进的点主要是skip connection。作者认为skip connection 直接将unet中encoder的浅层特征与decoder的深层特征结合是不妥当的,会产生semantic gap。整篇文章的一个假设就是,当所结合的浅层特征与深层特征是semantically similar时,网络的优化问题就会更简单,因此文章对skip connection的改进就是想bridge/reduce 这个semantic gap。
作为参考,先附一张原始Unet结构图如下
理解这篇文章的关键就是看懂文中的这张图。其中黑色部分代表的就是原始Unet结构,绿色代表添加的卷积层,蓝色代表改进的skip connection。
文章给出的公式可以较好地表示图中的结构:
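$$
x^{i,j} =
\begin{cases}
\mathcal{H}\left(x^{i-1,j}\right), & j = 0 \\
\mathcal{H}\left(\left[\left[x^{i,k}\right]_{k=0}^{j-1},\ \mathcal{U}\left(x^{i+1,j-1}\right)\right]\right), & j > 0
\end{cases}
$$
其中 $x^{i,j}$ 表示节点 $X^{i,j}$ 的输出,$i$ 表示沿encoder方向的下采样层序号,$j$ 表示沿skip connection方向的卷积层序号;$\mathcal{H}(\cdot)$ 表示卷积加激活函数的操作,$\mathcal{U}(\cdot)$ 表示上采样层,$[\,\cdot\,]$ 表示拼接(concatenation)层。例如,$x^{1,2}$ 是由 $x^{1,0}$、$x^{1,1}$ 以及上采样后的 $x^{2,1}$ 拼接之后,再经过一次conv与relu得到。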
与参数量相同的原始Unet相比,采用这种改进的Unet在4种不同的数据集上都得到了更好的分割效果。
除了对skip connection进行改进之外,文章还引入了deep supervision的思路。网络的loss函数是由不同层得到的分割图的loss的平均,每层的loss函数为DICE LOSS和Binary cross-entropy LOSS之和,如下所示。作者认为引入DSN(deep supervision net)后,通过model pruning(模型剪枝,如图2(c)所示)能够实现模型的两种模式:高精度模式和高速模式。
源码如下:
#https://github.com/MrGiovanni/Nested-UNet/blob/master/model.py
import keras
import tensorflow as tf
from keras.models import Model
from keras import backend as K
from keras.layers import Input, merge, Conv2D, ZeroPadding2D, UpSampling2D, Dense, concatenate, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D, GlobalAveragePooling2D, MaxPooling2D
from keras.layers.core import Dense, Dropout, Activation
from keras.layers import BatchNormalization, Dropout, Flatten, Lambda
from keras.layers.advanced_activations import ELU, LeakyReLU
from keras.optimizers import Adam, RMSprop, SGD
from keras.regularizers import l2
from keras.layers.noise import GaussianDropout
import numpy as np
# Module-level hyper-parameters shared by the model builders below.
smooth = 1.  # smoothing constant; not referenced in this part of the file
dropout_rate = 0.5  # rate passed to every Dropout layer in standard_unit
act = "relu"  # activation name passed to every Conv2D in standard_unit
########################################
# 2D Standard
########################################
def standard_unit(input_tensor, stage, nb_filter, kernel_size=3):
    """Apply two consecutive Conv2D + Dropout stages to ``input_tensor``.

    Both convolutions use ``nb_filter`` filters, a square
    ``kernel_size`` x ``kernel_size`` kernel, 'same' padding, 'he_normal'
    initialisation, L2 weight regularisation (1e-4) and the module-level
    activation ``act``. Each is followed by Dropout with the module-level
    ``dropout_rate``.

    Args:
        input_tensor: Keras tensor fed to the first convolution.
        stage: String suffix used to build the layer names
            ('conv<stage>_1', 'dp<stage>_1', 'conv<stage>_2', 'dp<stage>_2').
        nb_filter: Number of filters for both convolutions.
        kernel_size: Side length of the square convolution kernel.

    Returns:
        The output tensor of the second Dropout layer.
    """
    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act,
               name='conv' + stage + '_1', kernel_initializer='he_normal',
               padding='same', kernel_regularizer=l2(1e-4))(input_tensor)
    x = Dropout(dropout_rate, name='dp' + stage + '_1')(x)
    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act,
               name='conv' + stage + '_2', kernel_initializer='he_normal',
               padding='same', kernel_regularizer=l2(1e-4))(x)
    x = Dropout(dropout_rate, name='dp' + stage + '_2')(x)
    return x
########################################
def Nest_Net(img_rows, img_cols, color_type=1, num_class=1, deep_supervision=False):
    """Build the standard UNet++ (nested U-Net) model [Zhou et.al, 2018].

    Total params: 9,041,601 (figure quoted from the original author's note).

    Node ``convR_C`` sits at resolution level ``R`` (1 = full resolution,
    5 = bottleneck) and column ``C`` of the nested skip pathway. Every node
    with ``C > 1`` concatenates the transposed-convolution up-sampling of
    node ``(R+1, C-1)`` with all earlier nodes of the same row, then applies
    ``standard_unit``.

    Args:
        img_rows: Input height.
        img_cols: Input width.
        color_type: Number of input channels.
        num_class: Number of filters of each 1x1 sigmoid output convolution.
        deep_supervision: If True the model returns the four full-resolution
            outputs ('output_1' .. 'output_4'); otherwise only 'output_4'.

    Returns:
        A Keras ``Model`` mapping 'main_input' to the selected output(s).
    """
    nb_filter = [32, 64, 128, 256, 512]

    # Handle dimension ordering for different backends.
    # bn_axis is kept as a module-level global, as in the original code.
    global bn_axis
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
        img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
        bn_axis = 1
        img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(conv2_1)
    conv1_2 = concatenate([up1_2, conv1_1], name='merge12', axis=bn_axis)
    conv1_2 = standard_unit(conv1_2, stage='12', nb_filter=nb_filter[0])

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
    conv2_2 = concatenate([up2_2, conv2_1], name='merge22', axis=bn_axis)
    conv2_2 = standard_unit(conv2_2, stage='22', nb_filter=nb_filter[1])

    up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
    conv1_3 = concatenate([up1_3, conv1_1, conv1_2], name='merge13', axis=bn_axis)
    conv1_3 = standard_unit(conv1_3, stage='13', nb_filter=nb_filter[0])

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
    conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=bn_axis)
    conv3_2 = standard_unit(conv3_2, stage='32', nb_filter=nb_filter[2])

    up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
    conv2_3 = concatenate([up2_3, conv2_1, conv2_2], name='merge23', axis=bn_axis)
    conv2_3 = standard_unit(conv2_3, stage='23', nb_filter=nb_filter[1])

    up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
    conv1_4 = concatenate([up1_4, conv1_1, conv1_2, conv1_3], name='merge14', axis=bn_axis)
    conv1_4 = standard_unit(conv1_4, stage='14', nb_filter=nb_filter[0])

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1, conv2_2, conv2_3], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1, conv1_2, conv1_3, conv1_4], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    # One 1x1 sigmoid head per full-resolution nested node; all four are
    # created regardless of deep_supervision, as in the original code.
    nestnet_output_1 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_1',
                              kernel_initializer='he_normal', padding='same',
                              kernel_regularizer=l2(1e-4))(conv1_2)
    nestnet_output_2 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_2',
                              kernel_initializer='he_normal', padding='same',
                              kernel_regularizer=l2(1e-4))(conv1_3)
    nestnet_output_3 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_3',
                              kernel_initializer='he_normal', padding='same',
                              kernel_regularizer=l2(1e-4))(conv1_4)
    nestnet_output_4 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_4',
                              kernel_initializer='he_normal', padding='same',
                              kernel_regularizer=l2(1e-4))(conv1_5)

    # Keras 2 spells the Model keyword arguments `inputs` / `outputs`.
    if deep_supervision:
        model = Model(inputs=img_input, outputs=[nestnet_output_1,
                                                 nestnet_output_2,
                                                 nestnet_output_3,
                                                 nestnet_output_4])
    else:
        model = Model(inputs=img_input, outputs=[nestnet_output_4])

    return model