CNN + RNN(ConvLSTM2D)图像分割分类

CNN + RNN(ConvLSTM2D)

///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
第一次看到这个思想是在2018MICCAI会议论文,CFCM: Segmentation via Coarse to Fine Context Memory,做医学图像分割.
阅读数只有50但已收到一部分人邮箱Call,正好这段时间把ConvLSTM2D和BiConvLSTM2D都测试了下,趁着年前最后一天工作时间,将心得完善了下. 喜欢关注下,后面会写学习到的新东西,春节愉快! (2019-1-30)

原文:
在网上找了很多版本,都没有自己想要的
在一个普通的U-net加Res上修改的,使用keras框架
所以自己填坑踩坑再填坑,直接上代码和网络图,有问题讨论随时Call
训练网络主要用来做图像分割,加入LSTM为了让网络学习到长期依赖的信息.

U-net就不多赘述了,搞计算机视觉的应该都有接触,但是在CNN中加入RNN提取图像特征的确实不多,LSTM(长短期记忆机制)属于RNN中的衍生品,之后还有GRU(门控单元)是简化了的LSTM.
说白了就是在提取图像信息特征的时候类似提取序列特征思想一样提取图像的上下文信息(上文指单向LSTM,上下文指双向LSTM,双向LSTM也测试过了,正好今天有一个港外小伙伴私邮想让测试下,这里双向LSTM通过使用Bidirectional包装器).具体想了解LSTM机制的可以看下面这篇blogLSTM原理及实现

CNN可以提取空间特性,LSTM可以提取时间特性,ConvLSTM可以同时利用时空特性.
ConvLSTM核心本质还是和LSTM一样,将上一层的输出作下一层的输入。不同的地方在于加上卷积操作之后,为不仅能够得到时序关系,还能够像卷积层一样提取特征,提取空间特征。这样就能够得到时空特征。并且将状态与状态之间的切换也换成了卷积计算。
keras的ConvLSTM2D层,也是一个LSTM网络,但它的输入变换和循环变换是通过卷积实现的,ConvLSTM2D的输入和输出形状如下:
输入形状:
5D tensor(samples,time,channels,row,cols)
输出形状可选:
5D tensor(samples,time,output_row,output_col,filters) 返回每个序列的结果
4D tensor(samples,output_row,output_col,filters)只返回最后一个序列的结果
其中,time代表每一个输入样本的图像序列所具有的图像帧数,这样就用到了TimeDistributed包装器.
能明白我描述的东西就够了,理论性的东西感觉没有必要理解得太深入,知道在干什么就好.

后面直接实操,还需要掌握的有上面提到的两个包装器:TimeDistributed层和Bidirectional层(Keras自带)
1.使用TimeDistributed包装器,将一个图层应用于输入的每个时间片(就是把time维每一序列单独做卷积操作提取特征)
keras.layers.TimeDistributed(layer)
2.使用Bidirectional双向封装器,将单向LSTM扩展,前向传播的时候增加学习参数,利用到后面序列(未来)的信息提取特征,使用时包装在想使用的LSTM层就好
keras.layers.Bidirectional(layer, merge_mode=‘concat’, weights=None)

测试结果:
实测效果还不错,在推广使用,有效避免2D断层,3D上下界假阳假阴问题。
前处理的话针对不同领域分割图有不同的前处理方法,数据增强时使用了平移/旋转/噪点/场强增强等方法
个人建议序列值为10左右做尝试,除此之外还需考虑算力和效率之间的平衡.
(1)输入序列增加后单双向LSTM最优dice值均升高
(2)双向LSTM较单向收敛更稳定更快

#-*- coding:utf-8 -*-
"""
@Author   :Alex 
@Datetime :19-1-11 下午2:42
@contact: [email protected]
@File name:segmentation-minify/U_net_convlstm2d
@Software : PyCharm
@Desc: CNN+ConvLSTM
@==============================@
@       ___   __    _  __      @
@      / _ | / /__ | |/_/      @
@     / __ |/ / -_)>  <        @
@    /_/ |_/_/\__/_/|_|        @
@                       常敦瑞  @
@==============================@
"""

from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.utils.vis_utils import plot_model
from keras.layers.convolutional_recurrent import ConvLSTM2D


def get_unet(pretrained_weights=None, input_size=(None, 160, 240, 1)):
	inputs = Input(input_size)
	conv1 = TimeDistributed(Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(inputs)
	conv1 = TimeDistributed(Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(conv1)
	pool1 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(conv1)
	conv2 = TimeDistributed(Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(pool1)
	conv2 = TimeDistributed(Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(conv2)
	pool2 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(conv2)
	conv3 = TimeDistributed(Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(pool2)
	conv3 = TimeDistributed(Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(conv3)
	# pool3 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(conv3)
	# conv4 = TimeDistributed(Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(pool3)
	# conv4 = TimeDistributed(Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(conv4)
	drop4 = TimeDistributed(Dropout(0.5))(conv3)
	pool4 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(drop4)
	
	conv5 = TimeDistributed(Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(pool4)
	conv5 = TimeDistributed(Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal'))(conv5)
	drop5 = TimeDistributed(Dropout(0.5))(conv5)
	
	up6 = ConvLSTM2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal', return_sequences=True)(
		TimeDistributed(UpSampling2D(size=(2, 2)))(drop5))
	merge6 = concatenate([drop4, up6], axis=4)
	# conv6 = ConvLSTM2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal', return_sequences=True)(merge6)
	# conv6 = ConvLSTM2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal', return_sequences=True)(conv6)
	
	# up7 = ConvLSTM2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal', return_sequences=True)(
	# 	TimeDistributed(UpSampling2D(size=(2, 2)))(conv6))
	merge7 = concatenate([conv3, up6], axis=4)
	conv7 = ConvLSTM2D(256, 3, padding='same', return_sequences=True)(merge7)
	conv7 = ConvLSTM2D(256, 3, padding='same', return_sequences=True)(conv7)
	
	up8 = ConvLSTM2D(128, 2, padding='same',return_sequences=True)(
		TimeDistributed(UpSampling2D(size=(2, 2)))(conv7))
	merge8 = concatenate([conv2, up8], axis=4)
	conv8 = ConvLSTM2D(128, 3, padding='same', return_sequences=True)(merge8)
	conv8 = ConvLSTM2D(128, 3, padding='same', return_sequences=True)(conv8)
	
	up9 = ConvLSTM2D(64, 2, padding='same', return_sequences=True)(
		TimeDistributed(UpSampling2D(size=(2, 2)))(conv8))
	merge9 = concatenate([conv1, up9], axis=4)
	conv9 = ConvLSTM2D(64, 3, padding='same', return_sequences=True)(merge9)
	conv9 = ConvLSTM2D(64, 3, padding='same', return_sequences=True)(conv9)
	conv9 = TimeDistributed(Conv2D(2, 3, activation='relu', padding='same'))(conv9)
	# conv9 = ConvLSTM2D(2, 3, padding='same', return_sequences=True)(conv9)
	# conv10 = ConvLSTM2D(3, 1, activation='softmax', return_sequences=True)(conv9)
	conv10 = TimeDistributed(Conv2D(2, 1,activation='softmax', padding='same'))(conv9)
	
	model = Model(input=inputs, output=conv10)
	
	model.compile(optimizer=Adam(lr=1e-4), loss='categorical_crossentropy', metrics=['accuracy'])
	
	plot_model(model, to_file='MRI_brain_seg_UNet3D.png', show_shapes=True)
	model.summary()
	
	if (pretrained_weights):
		model.load_weights(pretrained_weights)
	
	return model

Plot To_file 网络图

CNN + RNN(ConvLSTM2D)图像分割分类_第1张图片
版权声明:本文为博主原创文章,未经博主允许不得转载

你可能感兴趣的:(图像处理,深度学习网络)