This post uses a 3D U-Net (Unet3d) to segment the BRATS2015 dataset.
BRATS2015 is a brain tumor segmentation dataset; the cases are split into two groups, HGG (high-grade gliomas) and LGG (low-grade gliomas).
The directory layout is shown in the figure.
Each case contains four MRI modalities (FLAIR, T1, T1c, T2), and OT is the ground truth, with five labels (0, 1, 2, 3, 4).
All data files are .mha volumes, which can be read directly with SimpleITK:
import SimpleITK as sitk

def sitk_read(img_path):
    img = sitk.ReadImage(img_path)      # SimpleITK Image object
    nda = sitk.GetArrayFromImage(img)   # numpy array, (155, 240, 240) = (depth, height, width)
    nda = nda.transpose(1, 2, 0)        # (240, 240, 155) = (height, width, depth)
    return nda
Calling SimpleITK.ReadImage returns an Image object, which GetArrayFromImage then converts to a numpy array. The array comes out with shape (depth, height, width), so it is transposed to (height, width, depth).
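For reference, a minimal usage sketch (the path below is a placeholder, not a real file from the dataset):

t1_path = "path/to/VSD.Brain.XX.O.MR_T1.XXXXX.mha"  # placeholder, replace with an actual .mha file
t1_volume = sitk_read(t1_path)
print(t1_volume.shape)  # expected: (240, 240, 155)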
In BRATS2015 only the training data comes with ground truth; the official validation data does not. The training data is therefore split into train, val, and test subsets for the network: read all training case names into a list, shuffle it with random.shuffle, then take 8/10 as train, 1/10 as val, and 1/10 as test. The three name lists are written to .txt files so they can be read back directly at training time.
def write_train_val_test_name_list(self):
    data_name_list = os.listdir(self.train_root_path + self.type + "\\")
    random.shuffle(data_name_list)
    length = len(data_name_list)
    n_train_file = int(length / 10 * 8)
    n_val_file = int(length / 10 * 1)
    train_name_list = data_name_list[0:n_train_file]
    val_name_list = data_name_list[n_train_file:(n_train_file + n_val_file)]
    test_name_list = data_name_list[(n_train_file + n_val_file):len(data_name_list)]
    self.write_name_list(train_name_list, "train_name_list.txt")
    self.write_name_list(val_name_list, "val_name_list.txt")
    self.write_name_list(test_name_list, "test_name_list.txt")

def write_name_list(self, name_list, file_name):
    with open(self.train_root_path + file_name, 'w') as f:
        for name in name_list:
            f.write(name + "\n")
The full data-reading code is as follows:
import os
import glob
import random
import numpy as np
import SimpleITK as sitk


def make_one_hot_3d(x, n):
    # x: (height, width, depth) integer label volume; n: number of classes
    one_hot = np.zeros([x.shape[0], x.shape[1], x.shape[2], n])
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for v in range(x.shape[2]):
                one_hot[i, j, v, int(x[i, j, v])] = 1
    return one_hot
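
# The triple loop above is easy to follow but slow on 240x240x155 volumes.
# A vectorized equivalent (my own sketch, assuming the labels are integers in [0, n)):
def make_one_hot_3d_fast(x, n):
    return np.eye(n, dtype=np.float32)[x.astype(np.int64)]  # (H, W, D) -> (H, W, D, n)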
class brast_reader:
    def __init__(self, train_batch_size, val_batch_size, test_batch_size, type='HGG'):
        self.train_root_path = "D:\\pyproject\\data\\BRATS2015\\BRATS2015_Training\\BRATS2015_Training\\"
        self.type = type
        self.train_name_list = self.load_file_name_list(self.train_root_path + "train_name_list.txt")
        self.val_name_list = self.load_file_name_list(self.train_root_path + "val_name_list.txt")
        self.test_name_list = self.load_file_name_list(self.train_root_path + "test_name_list.txt")
        self.n_train_file = len(self.train_name_list)
        self.n_val_file = len(self.val_name_list)
        self.n_test_file = len(self.test_name_list)
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.n_train_steps_per_epoch = self.n_train_file // self.train_batch_size
        self.n_val_steps_per_epoch = self.n_val_file // self.val_batch_size
        self.img_height = 240
        self.img_width = 240
        # note: raw volumes are 155 slices deep; img_depth is 160 (divisible by 16, as the four
        # 2x2x2 poolings require), so volumes need padding to this depth before they fit the
        # batch arrays built below
        self.img_depth = 160
        self.n_labels = 5
        self.train_batch_index = 0
        self.val_batch_index = 0
    def load_file_name_list(self, file_path):
        file_name_list = []
        with open(file_path, 'r') as file_to_read:
            while True:
                line = file_to_read.readline().strip()  # one case name per line
                if not line:
                    break
                file_name_list.append(line)
        return file_name_list
    def write_train_val_test_name_list(self):
        data_name_list = os.listdir(self.train_root_path + self.type + "\\")
        random.shuffle(data_name_list)
        length = len(data_name_list)
        n_train_file = int(length / 10 * 8)
        n_val_file = int(length / 10 * 1)
        train_name_list = data_name_list[0:n_train_file]
        val_name_list = data_name_list[n_train_file:(n_train_file + n_val_file)]
        test_name_list = data_name_list[(n_train_file + n_val_file):len(data_name_list)]
        self.write_name_list(train_name_list, "train_name_list.txt")
        self.write_name_list(val_name_list, "val_name_list.txt")
        self.write_name_list(test_name_list, "test_name_list.txt")

    def write_name_list(self, name_list, file_name):
        with open(self.train_root_path + file_name, 'w') as f:
            for name in name_list:
                f.write(name + "\n")
    def next_train_batch_2d(self):
        if self.train_batch_index >= self.n_train_file:
            self.train_batch_index = 0
        data_path = self.train_root_path + self.type + '\\' + self.train_name_list[self.train_batch_index]
        # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
        t1, ot = self.get_np_data_2d(data_path)
        train_imgs = t1[:, :, :, np.newaxis]                # (155, 240, 240, 1)
        train_labels = make_one_hot_3d(ot, self.n_labels)   # (155, 240, 240, 5)
        self.train_batch_index += 1
        return train_imgs, train_labels

    def next_val_batch_2d(self):
        if self.val_batch_index >= self.n_val_file:
            self.val_batch_index = 0
        data_path = self.train_root_path + self.type + '\\' + self.val_name_list[self.val_batch_index]
        # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
        t1, ot = self.get_np_data_2d(data_path)
        val_imgs = t1[:, :, :, np.newaxis]                  # (155, 240, 240, 1)
        val_labels = make_one_hot_3d(ot, self.n_labels)
        self.val_batch_index += 1
        return val_imgs, val_labels
    def next_train_batch_3d(self):
        train_imgs = np.zeros((self.train_batch_size, self.img_height, self.img_width, self.img_depth, 1))
        train_labels = np.zeros([self.train_batch_size, self.img_height, self.img_width, self.img_depth, self.n_labels])
        if self.train_batch_index >= self.n_train_steps_per_epoch:
            self.train_batch_index = 0
        for i in range(self.train_batch_size):
            data_path = self.train_root_path + self.type + '\\' + self.train_name_list[
                self.train_batch_size * self.train_batch_index + i]
            # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
            t1, ot = self.get_np_data_3d(data_path)
            # flair = flair[:, :, :, np.newaxis]
            t1 = t1[:, :, :, np.newaxis]
            # t1c = t1c[:, :, :, np.newaxis]
            # t2 = t2[:, :, :, np.newaxis]
            train_imgs[i] = t1
            one_hot = make_one_hot_3d(ot, self.n_labels)
            train_labels[i] = one_hot
        self.train_batch_index += 1
        return train_imgs, train_labels

    def next_val_batch_3d(self):
        val_imgs = np.zeros((self.val_batch_size, self.img_height, self.img_width, self.img_depth, 1))
        val_labels = np.zeros([self.val_batch_size, self.img_height, self.img_width, self.img_depth, self.n_labels])
        if self.val_batch_index >= self.n_val_steps_per_epoch:
            self.val_batch_index = 0
        for i in range(self.val_batch_size):
            data_path = self.train_root_path + self.type + '\\' + self.val_name_list[
                self.val_batch_size * self.val_batch_index + i]
            # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
            t1, ot = self.get_np_data_3d(data_path)
            # flair = flair[:, :, :, np.newaxis]
            t1 = t1[:, :, :, np.newaxis]
            # t1c = t1c[:, :, :, np.newaxis]
            # t2 = t2[:, :, :, np.newaxis]
            val_imgs[i] = t1
            one_hot = make_one_hot_3d(ot, self.n_labels)
            val_labels[i] = one_hot
        self.val_batch_index += 1
        return val_imgs, val_labels
    def get_np_data_3d(self, data_path):
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1.*\\VSD.Brain.XX.O.MR_T1.*.mha')):
            t1_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain_3more.XX.*\\VSD.Brain_3more.XX.*.mha')):
            ot_file_path = i
        '''
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_Flair.*\\VSD.Brain.XX.O.MR_Flair.*.mha')):
            flair_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1c.*\\VSD.Brain.XX.O.MR_T1c.*.mha')):
            t1c_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T2.*\\VSD.Brain.XX.O.MR_T2.*.mha')):
            t2_file_path = i
        flair = sitk_read(flair_file_path)
        t1c = sitk_read(t1c_file_path)
        t2 = sitk_read(t2_file_path)
        '''
        t1 = sitk_read(t1_file_path)   # (240, 240, 155)
        ot = sitk_read(ot_file_path)
        return t1, ot
        # return flair, t1, t1c, t2, ot

    def get_np_data_2d(self, data_path):
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1.*\\VSD.Brain.XX.O.MR_T1.*.mha')):
            t1_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain_3more.XX.*\\VSD.Brain_3more.XX.*.mha')):
            ot_file_path = i
        '''
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_Flair.*\\VSD.Brain.XX.O.MR_Flair.*.mha')):
            flair_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1c.*\\VSD.Brain.XX.O.MR_T1c.*.mha')):
            t1c_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T2.*\\VSD.Brain.XX.O.MR_T2.*.mha')):
            t2_file_path = i
        flair = sitk_read(flair_file_path)
        t1c = sitk_read(t1c_file_path)
        t2 = sitk_read(t2_file_path)
        '''
        # sitk_read_row is presumably a variant of sitk_read without the transpose,
        # returning the raw (155, 240, 240) slice-first layout used by the 2D methods
        t1 = sitk_read_row(t1_file_path)
        ot = sitk_read_row(ot_file_path)
        return t1, ot
        # return flair, t1, t1c, t2, ot
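A minimal usage sketch of the reader (my own, assuming the three name-list .txt files have already been generated and that volumes are padded to img_depth as noted in __init__):

reader = brast_reader(train_batch_size=1, val_batch_size=1, test_batch_size=1, type='HGG')
imgs, labels = reader.next_train_batch_3d()
print(imgs.shape, labels.shape)  # (1, 240, 240, 160, 1) and (1, 240, 240, 160, 5)

Next comes the 3D U-Net model itself: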
import keras.backend as K
from keras.engine import Input, Model
import keras
from keras.optimizers import Adam
from keras.layers import BatchNormalization, Activation, Conv3D, Conv3DTranspose, MaxPooling3D, UpSampling3D
import metrics as m  # local module providing dice_coefficient / dice_coefficient_loss
from keras.layers.core import Lambda
import numpy as np
def up_and_concate_3d(down_layer, layer):
    in_channel = down_layer.get_shape().as_list()[4]
    out_channel = in_channel // 2
    up = Conv3DTranspose(out_channel, [2, 2, 2], strides=[2, 2, 2], padding='valid')(down_layer)
    print("--------------")
    print(str(up.get_shape()))
    print(str(layer.get_shape()))
    print("--------------")
    # the concatenation must be wrapped in a Lambda layer: the functional API only accepts
    # layer outputs, not raw backend tensors such as K.concatenate(...)
    my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=4))
    concate = my_concat([up, layer])
    # concate = K.concatenate([up, layer], 3)
    return concate
def attention_block_3d(x, g, inter_channel):
    '''
    :param x: input from the same-level down-sampling layer, x(?, x_height, x_width, x_depth, x_channel)
    :param g: gate input from the previous up-sampling layer, g(?, g_height, g_width, g_depth, g_channel),
              where g_height, g_width, g_depth = x_height/2, x_width/2, x_depth/2
    :return: the attention-weighted skip connection
    '''
    # theta_x(?, g_height, g_width, g_depth, inter_channel)
    theta_x = Conv3D(inter_channel, [2, 2, 2], strides=[2, 2, 2])(x)
    # phi_g(?, g_height, g_width, g_depth, inter_channel)
    phi_g = Conv3D(inter_channel, [1, 1, 1], strides=[1, 1, 1])(g)
    # f(?, g_height, g_width, g_depth, inter_channel)
    f = Activation('relu')(keras.layers.add([theta_x, phi_g]))
    # psi_f(?, g_height, g_width, g_depth, 1)
    psi_f = Conv3D(1, [1, 1, 1], strides=[1, 1, 1])(f)
    # sigm_psi_f(?, g_height, g_width, g_depth, 1)
    sigm_psi_f = Activation('sigmoid')(psi_f)
    # rate(?, x_height, x_width, x_depth, 1)
    rate = UpSampling3D(size=[2, 2, 2])(sigm_psi_f)
    # att_x(?, x_height, x_width, x_depth, x_channel)
    att_x = keras.layers.multiply([x, rate])
    return att_x
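
# attention_block_3d is defined but never called in unet_model_3d below. A sketch of how it
# could be wired into the up path (my own assumption, not part of the original code):
def att_up_and_concate_3d(down_layer, layer):
    in_channel = down_layer.get_shape().as_list()[4]
    out_channel = in_channel // 2
    # gate the skip connection with the coarser feature map before upsampling and concatenating
    layer = attention_block_3d(x=layer, g=down_layer, inter_channel=out_channel)
    up = Conv3DTranspose(out_channel, [2, 2, 2], strides=[2, 2, 2], padding='valid')(down_layer)
    my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=4))
    return my_concat([up, layer])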
def unet_model_3d(input_shape, n_labels, batch_normalization=False, initial_learning_rate=0.00001,
                  metrics=m.dice_coefficient):
    """
    input_shape: without the batch dimension, i.e. (img_height, img_width, img_depth, channels)
    metrics: metric (or list of metrics) passed to model.compile
    """
    inputs = Input(input_shape)
    down_layer = []
    layer = inputs

    # down_layer_1
    layer = res_block_v2_3d(layer, 64, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)
    print(str(layer.get_shape()))

    # down_layer_2
    layer = res_block_v2_3d(layer, 128, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)
    print(str(layer.get_shape()))

    # down_layer_3
    layer = res_block_v2_3d(layer, 256, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)
    print(str(layer.get_shape()))

    # down_layer_4
    layer = res_block_v2_3d(layer, 512, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)
    print(str(layer.get_shape()))

    # bottle_layer
    layer = res_block_v2_3d(layer, 1024, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_4
    layer = up_and_concate_3d(layer, down_layer[3])
    layer = res_block_v2_3d(layer, 512, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_3
    layer = up_and_concate_3d(layer, down_layer[2])
    layer = res_block_v2_3d(layer, 256, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_2
    layer = up_and_concate_3d(layer, down_layer[1])
    layer = res_block_v2_3d(layer, 128, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_1
    layer = up_and_concate_3d(layer, down_layer[0])
    layer = res_block_v2_3d(layer, 64, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # score_layer: 1x1x1 convolution down to n_labels channels
    layer = Conv3D(n_labels, [1, 1, 1], strides=[1, 1, 1])(layer)
    print(str(layer.get_shape()))

    # softmax over the label channel
    layer = Activation('softmax')(layer)
    print(str(layer.get_shape()))

    outputs = layer
    model = Model(inputs=inputs, outputs=outputs)
    metrics = [metrics]
    model.compile(optimizer=Adam(lr=initial_learning_rate), loss=m.dice_coefficient_loss, metrics=metrics)
    return model
def res_block_v2_3d(input_layer, out_n_filters, batch_normalization=False, kernel_size=[3, 3, 3], stride=[1, 1, 1],
                    padding='same'):
    # channels-last 5D tensor: (batch, height, width, depth, channels), so channels sit at index 4
    input_n_filters = input_layer.get_shape().as_list()[4]
    print(str(input_layer.get_shape()))
    layer = input_layer
    for i in range(2):
        if batch_normalization:
            layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)
        layer = Conv3D(out_n_filters, kernel_size, strides=stride, padding=padding)(layer)
    # project the skip connection when the channel count changes
    if out_n_filters != input_n_filters:
        skip_layer = Conv3D(out_n_filters, [1, 1, 1], strides=stride, padding=padding)(input_layer)
    else:
        skip_layer = input_layer
    out_layer = keras.layers.add([layer, skip_layer])
    return out_layer
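
The metrics module itself is not included in the post. A common soft-Dice definition that matches the names used above (my own sketch, not necessarily the author's exact version), followed by how the model might be built and fed from the reader:

# metrics.py (sketch)
import keras.backend as K

def dice_coefficient(y_true, y_pred, smooth=1.0):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coefficient_loss(y_true, y_pred):
    return -dice_coefficient(y_true, y_pred)

# building and training the model (batch size 1, since full 3D volumes are very memory-hungry)
model = unet_model_3d(input_shape=(240, 240, 160, 1), n_labels=5, batch_normalization=True)
reader = brast_reader(train_batch_size=1, val_batch_size=1, test_batch_size=1, type='HGG')
for step in range(reader.n_train_steps_per_epoch):
    imgs, labels = reader.next_train_batch_3d()
    loss, dice = model.train_on_batch(imgs, labels)
    print(step, loss, dice)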
The network structure is essentially the same as the 2D U-Net used in the earlier VOC segmentation post; only the convolution, pooling, and concatenation operations have been replaced with their 3D counterparts.
Some bugs turned up while running this code; the latest, corrected version is on GitHub and should be treated as authoritative:
https://github.com/panxiaobai/brats_keras