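A long time ago I read the UNet3+ paper on medical image segmentation. I originally planned to grab a Keras/TensorFlow implementation from GitHub, but the ones I found all seemed to deviate from the official source code, so I wrote my own following the paper and the source. I can't guarantee that it matches the source exactly; I'm posting it as a starting point in the hope that it draws out better versions. Many blog posts on UNet3+ are quite good. For a walkthrough of the whole paper, see the translation 【UNet3+(UNet+++)论文解读 玖零猴】. This post also briefly covers my own understanding.
UNet3+ paper
Source code (PyTorch)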
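In short, UNet3+ has four main features:
1 Full-scale skip connections, which keep semantic information from being lost between downsampling and upsampling
2 Full-scale deep supervision, which learns deep hierarchical feature representations
3 A classification-guided module, proposed to remove false-positive segmentations caused by noise in medical images
4 A new hybrid loss function (TODO)
Well, the first three each have something to gripe about; more on that later.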
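The UNet3+ architecture is shown in the figure above and is, on the whole, very easy to follow. The authors argue that neither UNet nor UNet++ connects feature maps across scales, so they pass encoder information from different scales to the decoder and also pass information across decoder levels, in order to reduce information loss (simple and brute-force indeed =_=).
Take decoder 3 as an example. It fuses the features of encoders 1, 2 and 3 and of decoders 4 and 5 (decoder 5 is simply the bottleneck, En5). These features are brought to the spatial size of decoder 3 by max pooling (features coming from the encoder) or by upsampling (features coming from the decoder), and a convolution layer (conv + BN + ReLU in the source code) brings them all to the same number of channels. The concatenated feature maps then pass through one more conv + BN + ReLU block, and that is the output of the level.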
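The resizing rule described above can be sketched in a few lines of plain Python. This is only an illustration of the bookkeeping, not part of the original implementation: the function name `full_scale_sources` and the labels `En*`/`De*` are made up here, and a 5-level network is assumed.

```python
def full_scale_sources(level, depth=5):
    """List the feature maps fused by decoder `level` and how each is resized.

    Returns (source_name, operation, factor) tuples, where factor is the
    pooling or upsampling ratio needed to reach the resolution of `level`.
    """
    sources = []
    # Shallower encoder levels have larger maps: shrink them with max pooling.
    for i in range(1, level):
        sources.append((f"En{i}", "maxpool", 2 ** (level - i)))
    # The encoder map at the same level already has the right size.
    sources.append((f"En{level}", "identity", 1))
    # Deeper levels have smaller maps: enlarge them by upsampling.
    for j in range(level + 1, depth + 1):
        # The deepest "decoder" level is just the bottleneck encoder output.
        name = f"En{j}" if j == depth else f"De{j}"
        sources.append((name, "upsample", 2 ** (j - level)))
    return sources


for source in full_scale_sources(3):
    print(source)
```

For decoder 3 this lists En1 pooled by 4, En2 pooled by 2, En3 unchanged, De4 upsampled by 2 and En5 upsampled by 4, which is the same pattern the `xd3_from_*` tensors follow in the Keras code further down.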
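This figure illustrates the other two features: full-scale deep supervision and the classification-guided module (CGM).
Full-scale deep supervision computes the loss on the output of every decoder level.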
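With a multi-output Keras model, one way to express this is to give every side output its own loss entry. The helper below is a hypothetical sketch: the function name, the equal weights and the use of binary cross-entropy as a placeholder loss are all assumptions (the paper uses its own hybrid loss, which this post does not cover). Only the output names `output_de1` to `output_de5` and `cgm_output` come from the implementation later in this post.

```python
def deep_supervision_losses(seg_loss="binary_crossentropy",
                            cgm_loss="binary_crossentropy",
                            weights=(1.0, 1.0, 1.0, 1.0, 1.0),
                            using_cgm=False, cgm_weight=1.0):
    """Build `loss` and `loss_weights` dicts keyed by output layer name."""
    # One segmentation loss per decoder level (level 5 is the bottleneck).
    names = [f"output_de{i}" for i in range(1, 6)]
    losses = {name: seg_loss for name in names}
    loss_weights = dict(zip(names, weights))
    if using_cgm:
        # The classification head gets its own loss term.
        losses["cgm_output"] = cgm_loss
        loss_weights["cgm_output"] = cgm_weight
    return losses, loss_weights


losses, loss_weights = deep_supervision_losses(using_cgm=True)
print(losses)
print(loss_weights)
```

The two dicts are meant to be passed as `model.compile(optimizer=..., loss=losses, loss_weights=loss_weights)`. The same ground-truth mask then has to be supplied once per segmentation output during training.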
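To prevent false-positive segmentations caused by noise, the authors propose the classification-guided module. It is attached to the bottleneck of the network (the bottom of the encoder, En5). This level is the deepest, has the most feature maps and the smallest spatial size, so it may already have filtered out some of the noise. After this level the authors add a small classification head (Dropout + Conv1x1 + Pooling + Sigmoid) that outputs a probability indicating whether the target organ is present in the input image. Multiplying this classification result with the output of the segmentation head removes false positives.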
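To make the data flow of the classification head concrete, here is a small NumPy sketch of its forward pass at inference time (so Dropout is omitted). The function name `cgm_head`, the random weights and the 16x16x512 feature map are stand-ins chosen for illustration, not values from the paper or the source code.

```python
import numpy as np


def cgm_head(features, weights, bias):
    """Classification head on a bottleneck feature map.

    features: (H, W, C) array, weights: (C, K) kernel of a 1x1 convolution,
    bias: (K,) array. Returns K probabilities, one per class.
    """
    # A 1x1 convolution is a per-pixel linear map over the channel axis.
    logits_map = features @ weights + bias   # shape (H, W, K)
    # Global max pooling keeps the strongest response of each class.
    logits = logits_map.max(axis=(0, 1))     # shape (K,)
    # Sigmoid turns each logit into a probability.
    return 1.0 / (1.0 + np.exp(-logits))


rng = np.random.default_rng(0)
features = rng.normal(size=(16, 16, 512))        # stand-in for the En5 output
weights = rng.normal(scale=0.01, size=(512, 1))  # untrained, illustrative weights
prob = cgm_head(features, weights, np.zeros(1))
print(prob.shape)  # (1,)
```

The whole image is reduced to one number per class, which is why the head is so cheap compared with the segmentation branch.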
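That covers the features; now for the gripes:
1 Full-scale connections are nice, and the authors make a point of noting that UNet3+ has fewer parameters than UNet and UNet++. In practice, though, both training time and memory usage seem to be higher, apparently because UNet3+ performs more convolutions (for example, each UNet decoder level needs only 2 convolutions, whereas, per Fig. 2 above, each UNet3+ decoder level needs 6).
2 Haven't worked this one out yet
3 The CGM is only a simple module. In my own experiments it overfit quickly even with Dropout: the validation loss of the segmentation head was still falling while the CGM loss had already stopped falling and started to rise.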
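Gripe 1 can be made a bit more tangible with a back-of-the-envelope count. For a stride-1 convolution with 'same' padding, the number of multiply-accumulates per output pixel equals its number of weights, so comparing weights of the shallowest (full-resolution) decoder level gives a rough idea of the compute gap. The sketch below assumes 64 initial features, 3x3 kernels and 64 channels per branch, ignores biases, normalization and the UNet up-convolution, and uses helper names invented for this example.

```python
def conv_params(c_in, c_out, k=3):
    """Weights in one k x k convolution without bias."""
    return k * k * c_in * c_out


def unet_decoder_params(skip_ch, k=3):
    """Two convolutions: (upsampled + skip) -> skip_ch, then skip_ch -> skip_ch."""
    return conv_params(2 * skip_ch, skip_ch, k) + conv_params(skip_ch, skip_ch, k)


def unet3p_decoder_params(source_channels, branch_ch=64, k=3):
    """One convolution per incoming branch plus one fusion convolution."""
    branches = sum(conv_params(c, branch_ch, k) for c in source_channels)
    fused = branch_ch * len(source_channels)
    return branches + conv_params(fused, fused, k)


# Shallowest decoder level: En1 (64 channels), De2..De4 (320 each), En5 (1024).
unet = unet_decoder_params(64)
unet3p = unet3p_decoder_params([64, 320, 320, 320, 1024])
print(unet, unet3p, unet3p / unet)
```

Under these assumptions the full-resolution level of UNet3+ does well over ten times the per-pixel work of its UNet counterpart. That fits the observation above: a lower total parameter count does not imply less compute or memory, because all six convolutions of the level run on the largest feature maps.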
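Note: I'm just a beginner who wrote this code for fun, so it is not guaranteed to be correct. If you spot a problem, feel free to point it out and discuss. Please credit the source when reposting.
The way the CGM output is implemented is still open to debate. In my code the CGM result and the segmentation masks are returned as separate outputs, so they have to be multiplied together manually afterwards.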
import tensorflow as tf
import numpy as np
from keras.models import Model
from keras.layers import (Conv2D, Input, concatenate, add, MaxPooling2D, AveragePooling2D,
                          UpSampling2D, Activation, BatchNormalization, LayerNormalization,
                          Dropout, GlobalMaxPooling2D)
# helper functions to build unet3+
def normalization(input_tensor, normalization):
    # apply the requested normalization layer, or pass the tensor through unchanged
    if normalization == 'batch':
        return BatchNormalization()(input_tensor)
    elif normalization == 'layer':
        return LayerNormalization()(input_tensor)
    elif normalization is None:
        return input_tensor
    else:
        raise ValueError('Invalid normalization')
def conv2d_block(input_tensor, filters, kernel_size,
                 norm_type, use_residual, act_type='relu',
                 double_features=False, dilation=(1, 1)):
    # first convolution: conv -> norm -> activation
    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[0], use_bias=False, kernel_initializer='he_normal')(input_tensor)
    x = normalization(x, norm_type)
    x = Activation(act_type)(x)
    if double_features:
        filters *= 2
    # second convolution; its activation is applied after the optional residual add
    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[1], use_bias=False, kernel_initializer='he_normal')(x)
    x = normalization(x, norm_type)
    if use_residual:
        # project the shortcut with a 1x1 convolution when the channel counts differ
        if input_tensor.shape[-1] != x.shape[-1]:
            shortcut = Conv2D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(input_tensor)
            shortcut = normalization(shortcut, norm_type)
            x = add([x, shortcut])
        else:
            x = add([x, input_tensor])
    x = Activation(act_type)(x)
    return x
def down_layer_2d(input_tensor, down_pattern, filters, norm_type=None):
    if down_pattern == 'maxpooling':
        x = MaxPooling2D(pool_size=(2, 2))(input_tensor)
    elif down_pattern == 'avgpooling':
        x = AveragePooling2D(pool_size=(2, 2))(input_tensor)
    elif down_pattern == 'conv':
        # strided convolution followed by normalization; the bias is only needed when no norm layer follows
        x = Conv2D(filters, kernel_size=(2, 2), strides=(2, 2), padding='same', use_bias=norm_type is None, kernel_initializer='he_normal')(input_tensor)
        x = normalization(x, norm_type)
    elif down_pattern == 'normconv':
        # normalization first, then strided convolution
        x = normalization(input_tensor, norm_type)
        x = Conv2D(filters, kernel_size=(2, 2), strides=(2, 2), padding='same', kernel_initializer='he_normal')(x)
    else:
        raise ValueError('Invalid down_pattern')
    return x
def conv_norm_act(input_tensor, filters, kernel_size, norm_type='batch', act_type='relu', dilation=1):
    # conv -> norm -> activation; the bias is dropped when a norm layer follows
    output_tensor = Conv2D(filters, kernel_size, padding='same', dilation_rate=(dilation, dilation), use_bias=norm_type is None, kernel_initializer='he_normal')(input_tensor)
    output_tensor = normalization(output_tensor, normalization=norm_type)
    output_tensor = Activation(act_type)(output_tensor)
    return output_tensor
def aggregate(l1, l2, l3, l4, l5, filters, kernel_size, norm_type='batch', act_type='relu'):
    # concatenate the five resized branches and fuse them with conv -> norm -> activation
    out = concatenate([l1, l2, l3, l4, l5], axis=-1)
    out = Conv2D(filters * 5, kernel_size, padding='same', use_bias=norm_type is None, kernel_initializer='he_normal')(out)
    out = normalization(out, norm_type)
    out = Activation(act_type)(out)
    return out
def cgm_block(input_tensor, class_num, dropout_rate=0.):
    x = Dropout(rate=dropout_rate)(input_tensor)
    x = Conv2D(class_num, 1, padding='same', kernel_initializer='he_normal')(x)
    # x = BatchNormalization()(x)
    x = GlobalMaxPooling2D()(x)  # global max pooling replaces the adaptive max pooling of the original; the result should be the same here
    x = Activation('sigmoid', name='cgm_output')(x)
    # x = Lambda(lambda x: K.expand_dims(x, axis=1))(x)
    # x = Lambda(lambda x: K.expand_dims(x, axis=1), name = 'cgm_output')(x)
    # x = Reshape((batch_size, 1, 1, class_num))(x)
    return x
# build unet3+ model
def unet3p_2d(input_shape, initial_features=32, kernel_size=3,
              class_num=1, norm_type='batch', double_features=False,
              use_residual=False, down_pattern='maxpooling', using_deep_supervision=True,
              using_cgm=False, cgm_drop_rate=0.5, show_summary=True):
    '''
    input_shape: (height, width, channel)
    initial_features: int, number of feature maps at the first level, doubled at every downsampling step; the unet3+ paper uses 64
    kernel_size: int, convolution kernel size
    class_num: int, number of segmentation classes
    norm_type: str, normalization type, 'batch' or 'layer'; unet3+ uses BatchNormalization
    double_features: bool, whether the second convolution in conv2d_block doubles the number of feature maps; the 3D U-Net paper proposes this to avoid bottlenecks, it can usually be left False
    use_residual: bool, whether the encoder uses residual connections
    down_pattern: str, downsampling method, 'maxpooling', 'avgpooling', 'conv' or 'normconv'; unet3+ uses MaxPooling
    using_deep_supervision: bool, whether to use full-scale deep supervision
    using_cgm: bool, whether to use the classification-guided module (CGM)
    cgm_drop_rate: float, Dropout rate inside the CGM
    show_summary: bool, whether to print the model summary
    '''
    if class_num == 1:
        last_layer_activation = 'sigmoid'
    else:
        last_layer_activation = 'softmax'
    inputs = Input(input_shape)
    # encoder
    xe1 = conv2d_block(input_tensor=inputs, filters=initial_features, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe1_pool = down_layer_2d(input_tensor=xe1, down_pattern=down_pattern, filters=initial_features)
    xe2 = conv2d_block(input_tensor=xe1_pool, filters=initial_features * 2, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe2_pool = down_layer_2d(input_tensor=xe2, down_pattern=down_pattern, filters=initial_features * 2)
    xe3 = conv2d_block(input_tensor=xe2_pool, filters=initial_features * 4, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe3_pool = down_layer_2d(input_tensor=xe3, down_pattern=down_pattern, filters=initial_features * 4)
    xe4 = conv2d_block(input_tensor=xe3_pool, filters=initial_features * 8, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe4_pool = down_layer_2d(input_tensor=xe4, down_pattern=down_pattern, filters=initial_features * 8)
    xe5 = conv2d_block(input_tensor=xe4_pool, filters=initial_features * 16, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    # classification-guided module on the bottleneck
    if using_cgm:
        cgm = cgm_block(input_tensor=xe5, class_num=class_num, dropout_rate=cgm_drop_rate)
    # decoder 4: fuses xe5 (up x2), xe4, xe3 (pool x2), xe2 (pool x4), xe1 (pool x8)
    xd4_from_xe5 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xe5)
    xd4_from_xe5 = conv_norm_act(input_tensor=xd4_from_xe5, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4_from_xe4 = conv_norm_act(input_tensor=xe4, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4_from_xe3 = MaxPooling2D(pool_size=(2, 2))(xe3)
    xd4_from_xe3 = conv_norm_act(input_tensor=xd4_from_xe3, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4_from_xe2 = MaxPooling2D(pool_size=(4, 4))(xe2)
    xd4_from_xe2 = conv_norm_act(input_tensor=xd4_from_xe2, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4_from_xe1 = MaxPooling2D(pool_size=(8, 8))(xe1)
    xd4_from_xe1 = conv_norm_act(input_tensor=xd4_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4 = aggregate(xd4_from_xe5, xd4_from_xe4, xd4_from_xe3, xd4_from_xe2, xd4_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    # decoder 3: fuses xe5 (up x4), xd4 (up x2), xe3, xe2 (pool x2), xe1 (pool x4)
    xd3_from_xe5 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xe5)
    xd3_from_xe5 = conv_norm_act(input_tensor=xd3_from_xe5, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3_from_xd4 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd4)
    xd3_from_xd4 = conv_norm_act(input_tensor=xd3_from_xd4, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3_from_xe3 = conv_norm_act(input_tensor=xe3, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3_from_xe2 = MaxPooling2D(pool_size=(2, 2))(xe2)
    xd3_from_xe2 = conv_norm_act(input_tensor=xd3_from_xe2, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3_from_xe1 = MaxPooling2D(pool_size=(4, 4))(xe1)
    xd3_from_xe1 = conv_norm_act(input_tensor=xd3_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3 = aggregate(xd3_from_xe5, xd3_from_xd4, xd3_from_xe3, xd3_from_xe2, xd3_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    # decoder 2: fuses xe5 (up x8), xd4 (up x4), xd3 (up x2), xe2, xe1 (pool x2)
    xd2_from_xe5 = UpSampling2D(size=(8, 8), interpolation='bilinear')(xe5)
    xd2_from_xe5 = conv_norm_act(input_tensor=xd2_from_xe5, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2_from_xd4 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xd4)
    xd2_from_xd4 = conv_norm_act(input_tensor=xd2_from_xd4, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2_from_xd3 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd3)
    xd2_from_xd3 = conv_norm_act(input_tensor=xd2_from_xd3, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2_from_xe2 = conv_norm_act(input_tensor=xe2, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2_from_xe1 = MaxPooling2D(pool_size=(2, 2))(xe1)
    xd2_from_xe1 = conv_norm_act(input_tensor=xd2_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2 = aggregate(xd2_from_xe5, xd2_from_xd4, xd2_from_xd3, xd2_from_xe2, xd2_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    # decoder 1: fuses xe5 (up x16), xd4 (up x8), xd3 (up x4), xd2 (up x2), xe1
    xd1_from_xe5 = UpSampling2D(size=(16, 16), interpolation='bilinear')(xe5)
    xd1_from_xe5 = conv_norm_act(input_tensor=xd1_from_xe5, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1_from_xd4 = UpSampling2D(size=(8, 8), interpolation='bilinear')(xd4)
    xd1_from_xd4 = conv_norm_act(input_tensor=xd1_from_xd4, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1_from_xd3 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xd3)
    xd1_from_xd3 = conv_norm_act(input_tensor=xd1_from_xd3, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1_from_xd2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd2)
    xd1_from_xd2 = conv_norm_act(input_tensor=xd1_from_xd2, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1_from_xe1 = conv_norm_act(input_tensor=xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1 = aggregate(xd1_from_xe5, xd1_from_xd4, xd1_from_xd3, xd1_from_xd2, xd1_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    if using_deep_supervision:
        # one side output per level, each upsampled to the input resolution
        xd55 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xe5)
        xd55 = UpSampling2D(size=(16, 16))(xd55)
        xd55 = Activation(last_layer_activation, name='output_de5')(xd55)
        xd44 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd4)
        xd44 = UpSampling2D(size=(8, 8))(xd44)
        xd44 = Activation(last_layer_activation, name='output_de4')(xd44)
        xd33 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd3)
        xd33 = UpSampling2D(size=(4, 4))(xd33)
        xd33 = Activation(last_layer_activation, name='output_de3')(xd33)
        xd22 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd2)
        xd22 = UpSampling2D(size=(2, 2))(xd22)
        xd22 = Activation(last_layer_activation, name='output_de2')(xd22)
        xd11 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd1)
        xd11 = Activation(last_layer_activation, name='output_de1')(xd11)
        if using_cgm:
            outputs = [xd11, xd22, xd33, xd44, xd55, cgm]
        else:
            outputs = [xd11, xd22, xd33, xd44, xd55]
    else:
        # single segmentation output from the shallowest decoder level
        conv_output = Conv2D(class_num, 1, activation=last_layer_activation, name='output')(xd1)
        if using_cgm:
            outputs = [conv_output, cgm]
        else:
            outputs = conv_output
    model = Model(inputs, outputs)
    if show_summary:
        model.summary()
    return model
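Since the model returns the CGM result and the segmentation masks separately, the manual multiplication mentioned before the code could look roughly like the NumPy sketch below. The function name `apply_cgm`, the 0.5 threshold and the hard 0/1 gate are assumptions made for illustration. The inputs are meant to be arrays as returned by `model.predict`, replaced here by dummy data.

```python
import numpy as np


def apply_cgm(mask_prob, cgm_prob, threshold=0.5):
    """Zero out the masks of images that the classifier considers empty.

    mask_prob: (N, H, W, K) segmentation probabilities.
    cgm_prob:  (N, K) classification probabilities from the CGM output.
    """
    # Binarise the classifier output: 1 = organ present, 0 = absent.
    present = (cgm_prob >= threshold).astype(mask_prob.dtype)
    # Reshape (N, K) to (N, 1, 1, K) so the gate multiplies every pixel.
    return mask_prob * present[:, None, None, :]


masks = np.full((2, 4, 4, 1), 0.9, dtype=np.float32)  # dummy predictions
cgm = np.array([[0.8], [0.1]], dtype=np.float32)      # image 2 judged empty
gated = apply_cgm(masks, cgm)
print(gated[0].max(), gated[1].max())
```

The first image keeps its mask and the second one is wiped entirely. With deep supervision enabled, the same gate would be applied to each of the five segmentation outputs.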
If all of the code above lives in the same .py file, you can append the following to try building the network:
if __name__ == '__main__':
    # deep supervision, no CGM
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling',
                      using_deep_supervision=True, using_cgm=False, show_summary=True)
    # deep supervision + CGM
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling',
                      using_deep_supervision=True, using_cgm=True, show_summary=True)
    # single output, no deep supervision, no CGM
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling',
                      using_deep_supervision=False, using_cgm=False, show_summary=True)
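One thing the test calls above rely on implicitly: the network halves the resolution four times and later upsamples by fixed factors of up to 16, so the input height and width should be multiples of 16. Otherwise pooling rounds down and the branches no longer line up at the concatenation. A small guard like the following can catch this early; the function name `check_input_shape` is made up for this sketch.

```python
def check_input_shape(input_shape, depth=5):
    """Raise if height or width cannot be halved `depth - 1` times exactly.

    Returns the spatial size of the bottleneck feature map.
    """
    factor = 2 ** (depth - 1)
    height, width = input_shape[0], input_shape[1]
    if height % factor or width % factor:
        raise ValueError(
            f"height and width must be multiples of {factor}, got {height}x{width}"
        )
    return height // factor, width // factor


print(check_input_shape((256, 256, 1)))  # (16, 16)
```

Calling it before `unet3p_2d` turns a confusing shape error deep inside the graph into a clear message. Padding or cropping the images to the next multiple of 16 is the usual workaround.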
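If you use a pretrained backbone, the encoder (En) part needs to be modified.
I feel like such a noob. Not sure whether I'll manage to graduate ("be yeah") smoothly, sigh TAT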