[深度学习从入门到女装] Keras in Practice: U-Net3D (BRATS2015)

This post uses a 3D U-Net to segment the BRATS2015 dataset.

Reading the BRATS2015 data

BRATS2015 is a brain-tumor segmentation dataset split into two groups: HGG (high-grade glioma) and LGG (low-grade glioma).

The folder layout looks like this:

[Figure 1: folder structure of a BRATS2015 case]

Each case contains four MRI modalities (FLAIR, T1, T1c, T2); OT is the ground truth, with five labels (0, 1, 2, 3, 4).

The data files are all .mha files, which SimpleITK can read directly:

import SimpleITK as sitk

def sitk_read(img_path):
    nda = sitk.ReadImage(img_path)         # SimpleITK Image object
    nda = sitk.GetArrayFromImage(nda)      # numpy array, (155, 240, 240) = (depth, height, width)
    nda = nda.transpose(1, 2, 0)           # (240, 240, 155) = (height, width, depth)
    return nda

SimpleITK.ReadImage returns an Image object, which GetArrayFromImage then converts to a numpy array; note that the array comes out as (depth, height, width).

In BRATS only the training set ships with ground truth (the official validation set has none), so the training set itself has to be split into train, val, and test for the network. First read all training case names into a list, shuffle it with random, then take 8/10 as train, 1/10 as val, and 1/10 as test, and write the three name lists to txt files so they can be loaded directly at training time.

    def write_train_val_test_name_list(self):
        data_name_list = os.listdir(self.train_root_path + self.type + "\\")
        random.shuffle(data_name_list)
        length = len(data_name_list)
        n_train_file = int(length / 10 * 8)  # 8/10 train
        n_val_file = int(length / 10)        # 1/10 val; the remainder is test
        train_name_list = data_name_list[0:n_train_file]
        val_name_list = data_name_list[n_train_file:n_train_file + n_val_file]
        test_name_list = data_name_list[n_train_file + n_val_file:]
        self.write_name_list(train_name_list, "train_name_list.txt")
        self.write_name_list(val_name_list, "val_name_list.txt")
        self.write_name_list(test_name_list, "test_name_list.txt")

    def write_name_list(self, name_list, file_name):
        with open(self.train_root_path + file_name, 'w') as f:
            for name in name_list:
                f.write(name + "\n")
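Note that brast_reader.__init__ (below) immediately loads the three txt files, so they must be generated before the class is first constructed. A one-off script along these lines (assuming the same paths as the class) could do it:

import os
import random

root = "D:\\pyproject\\data\\BRATS2015\\BRATS2015_Training\\BRATS2015_Training\\"
names = os.listdir(os.path.join(root, "HGG"))
random.shuffle(names)
n_train, n_val = len(names) * 8 // 10, len(names) // 10
splits = {
    "train_name_list.txt": names[:n_train],
    "val_name_list.txt": names[n_train:n_train + n_val],
    "test_name_list.txt": names[n_train + n_val:],
}
for file_name, split in splits.items():
    with open(os.path.join(root, file_name), 'w') as f:
        f.write("\n".join(split) + "\n")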


The complete reader code is as follows:
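Two things are missing from the listing as posted: the module imports, and the helper sitk_read_row that get_np_data_2d calls. The imports below follow from the code itself; sitk_read_row is my reconstruction (assumption: the raw, untransposed read, which is what the (155, 240, 240) shape comments in the 2D batch methods imply):

import os
import glob
import random
import numpy as np
import SimpleITK as sitk


def sitk_read_row(img_path):
    # Assumed helper: like sitk_read, but without the transpose, keeping the
    # raw (depth, height, width) = (155, 240, 240) order for slice-wise 2D use.
    nda = sitk.ReadImage(img_path)
    return sitk.GetArrayFromImage(nda)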

def make_one_hot_3d(x, n):
    # x: integer label volume of shape (height, width, depth); n: number of classes.
    one_hot = np.zeros([x.shape[0], x.shape[1], x.shape[2], n])
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for v in range(x.shape[2]):
                one_hot[i, j, v, int(x[i, j, v])] = 1
    return one_hot
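The triple Python loop above is very slow on a 240×240×155 volume. A vectorized drop-in replacement (assuming integer labels in [0, n), which holds for the 0-4 BRATS labels) is:

def make_one_hot_3d_fast(x, n):
    # Index an identity matrix with the label volume; output shape is (*x.shape, n).
    return np.eye(n)[x.astype(np.int32)]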


class brast_reader:
    def __init__(self, train_batch_size, val_batch_size, test_batch_size, type='HGG'):
        self.train_root_path = "D:\\pyproject\\data\\BRATS2015\\BRATS2015_Training\\BRATS2015_Training\\"
        self.type = type

        self.train_name_list = self.load_file_name_list(self.train_root_path + "train_name_list.txt")
        self.val_name_list = self.load_file_name_list(self.train_root_path + "val_name_list.txt")
        self.test_name_list = self.load_file_name_list(self.train_root_path + "test_name_list.txt")

        self.n_train_file = len(self.train_name_list)
        self.n_val_file = len(self.val_name_list)
        self.n_test_file = len(self.test_name_list)

        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size

        self.n_train_steps_per_epoch = self.n_train_file // self.train_batch_size
        self.n_val_steps_per_epoch = self.n_val_file // self.val_batch_size

        self.img_height = 240
        self.img_width = 240
        self.img_depth = 160  # the volumes have 155 slices; 160 is divisible by 2^4 for the four pooling stages, so the depth must be padded (see the padding sketch after this class)
        self.n_labels = 5

        self.train_batch_index = 0
        self.val_batch_index = 0

    def load_file_name_list(self, file_path):
        file_name_list = []
        with open(file_path, 'r') as file_to_read:
            while True:
                line = file_to_read.readline().strip()  # one case name per line
                if not line:
                    break
                file_name_list.append(line)
        return file_name_list

    def write_train_val_test_name_list(self):
        data_name_list = os.listdir(self.train_root_path + self.type + "\\")
        random.shuffle(data_name_list)
        length = len(data_name_list)
        n_train_file = int(length / 10 * 8)  # 8/10 train
        n_val_file = int(length / 10)        # 1/10 val; the remainder is test
        train_name_list = data_name_list[0:n_train_file]
        val_name_list = data_name_list[n_train_file:n_train_file + n_val_file]
        test_name_list = data_name_list[n_train_file + n_val_file:]
        self.write_name_list(train_name_list, "train_name_list.txt")
        self.write_name_list(val_name_list, "val_name_list.txt")
        self.write_name_list(test_name_list, "test_name_list.txt")

    def write_name_list(self, name_list, file_name):
        with open(self.train_root_path + file_name, 'w') as f:
            for name in name_list:
                f.write(name + "\n")

    def next_train_batch_2d(self):
        if self.train_batch_index >= self.n_train_file:
            self.train_batch_index = 0

        data_path = self.train_root_path + self.type + '\\' + self.train_name_list[self.train_batch_index]

        # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
        t1, ot = self.get_np_data_2d(data_path)
        train_imgs = t1[:, :, :, np.newaxis]               # (155, 240, 240, 1): each slice is a 2D sample
        train_labels = make_one_hot_3d(ot, self.n_labels)  # (155, 240, 240, 5)

        self.train_batch_index += 1
        return train_imgs, train_labels

    def next_val_batch_2d(self):
        if self.val_batch_index >= self.n_val_file:
            self.val_batch_index = 0

        data_path = self.train_root_path + self.type + '\\' + self.val_name_list[self.val_batch_index]

        # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
        t1, ot = self.get_np_data_2d(data_path)
        val_imgs = t1[:, :, :, np.newaxis]               # (155, 240, 240, 1)
        val_labels = make_one_hot_3d(ot, self.n_labels)  # (155, 240, 240, 5)
        self.val_batch_index += 1

        return val_imgs, val_labels




    def next_train_batch_3d(self):
        train_imgs = np.zeros((self.train_batch_size, self.img_height, self.img_width, self.img_depth, 1))
        train_labels = np.zeros([self.train_batch_size, self.img_height, self.img_width, self.img_depth, self.n_labels])
        if self.train_batch_index >= self.n_train_steps_per_epoch:
            self.train_batch_index = 0
        for i in range(self.train_batch_size):
            data_path = self.train_root_path + self.type + '\\' + self.train_name_list[
                self.train_batch_size * self.train_batch_index + i]

            # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
            t1, ot = self.get_np_data_3d(data_path)
            # flair = flair[:, :, :, np.newaxis]
            t1 = t1[:, :, :, np.newaxis]  # (240, 240, 155, 1); the 155-slice depth must be padded to img_depth=160 for this assignment to work
            # t1c = t1c[:, :, :, np.newaxis]
            # t2 = t2[:, :, :, np.newaxis]
            train_imgs[i] = t1
            one_hot = make_one_hot_3d(ot, self.n_labels)
            train_labels[i] = one_hot

        self.train_batch_index += 1

        return train_imgs, train_labels

    def next_val_batch_3d(self):
        val_imgs = np.zeros((self.val_batch_size, self.img_height, self.img_width, self.img_depth, 1))
        val_labels = np.zeros([self.val_batch_size, self.img_height, self.img_width, self.img_depth, self.n_labels])
        if self.val_batch_index >= self.n_val_steps_per_epoch:
            self.val_batch_index = 0
        for i in range(self.val_batch_size):
            data_path = self.train_root_path + self.type + '\\' + self.val_name_list[
                self.val_batch_size * self.val_batch_index + i]

            # flair, t1, t1c, t2, ot=self.get_np_data(data_path)
            t1, ot = self.get_np_data_3d(data_path)
            # flair=flair[:,:,:,np.newaxis]
            t1 = t1[:, :, :, np.newaxis]
            # t1c = t1c[:, :, :, np.newaxis]
            # t2 = t2[:, :, :, np.newaxis]
            val_imgs[i] = t1
            one_hot = make_one_hot_3d(ot, self.n_labels)
            val_labels[i] = one_hot

        self.val_batch_index += 1

        return val_imgs, val_labels

    def get_np_data_3d(self, data_path):
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1.*\\VSD.Brain.XX.O.MR_T1.*.mha')):
            t1_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain_3more.XX.*\\VSD.Brain_3more.XX.*.mha')):
            ot_file_path = i

        '''
        for i in glob.glob(os.path.join(data_path,'VSD.Brain.XX.O.MR_Flair.*\\VSD.Brain.XX.O.MR_Flair.*.mha')):
            flair_file_path=i

        

        for i in glob.glob(os.path.join(data_path,'VSD.Brain.XX.O.MR_T1c.*\\VSD.Brain.XX.O.MR_T1c.*.mha')):
            t1c_file_path=i

        for i in glob.glob(os.path.join(data_path,'VSD.Brain.XX.O.MR_T2.*\\VSD.Brain.XX.O.MR_T2.*.mha')):
            t2_file_path=i

        

        flair=sitk_read(flair_file_path)
        
        t1c=sitk_read(t1c_file_path)
        t2=sitk_read(t2_file_path)
        
        '''
        t1 = sitk_read(t1_file_path)
        ot = sitk_read(ot_file_path)
        return t1, ot
        # return flair,t1,t1c,t2,ot

    def get_np_data_2d(self, data_path):
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1.*\\VSD.Brain.XX.O.MR_T1.*.mha')):
            t1_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain_3more.XX.*\\VSD.Brain_3more.XX.*.mha')):
            ot_file_path = i

        '''
        for i in glob.glob(os.path.join(data_path,'VSD.Brain.XX.O.MR_Flair.*\\VSD.Brain.XX.O.MR_Flair.*.mha')):
            flair_file_path=i



        for i in glob.glob(os.path.join(data_path,'VSD.Brain.XX.O.MR_T1c.*\\VSD.Brain.XX.O.MR_T1c.*.mha')):
            t1c_file_path=i

        for i in glob.glob(os.path.join(data_path,'VSD.Brain.XX.O.MR_T2.*\\VSD.Brain.XX.O.MR_T2.*.mha')):
            t2_file_path=i



        flair=sitk_read(flair_file_path)

        t1c=sitk_read(t1c_file_path)
        t2=sitk_read(t2_file_path)

        '''
        t1 = sitk_read_row(t1_file_path)
        ot = sitk_read_row(ot_file_path)
        return t1, ot
        # return flair,t1,t1c,t2,ot
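As flagged in the comments above, sitk_read returns (240, 240, 155) volumes while the 3D batch arrays expect img_depth = 160 (divisible by 2^4, matching the four pooling stages). The post does not show the fix (the author points to the GitHub repo for the corrected code); a minimal zero-padding sketch that could be applied to t1 and ot before stacking would be:

def pad_depth(vol, target_depth=160):
    # Zero-pad the last (depth) axis, e.g. 155 -> 160, so volumes fit the
    # (img_height, img_width, img_depth) batch arrays.
    pad = target_depth - vol.shape[2]
    return np.pad(vol, ((0, 0), (0, 0), (0, pad)), mode='constant')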


U-Net3D


import keras.backend as K
from keras.engine import Input, Model
import keras
from keras.optimizers import Adam
from keras.layers import BatchNormalization, Activation, Conv3D, Conv3DTranspose, MaxPooling3D, UpSampling3D
import metrics as m
from keras.layers.core import Lambda
import numpy as np
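metrics here is the author's own module and is not included in the post. A standard soft-Dice pair matching the two names used below (my assumption, not necessarily the author's exact implementation) would look like:

# Assumed contents of metrics.py (requires: import keras.backend as K)
def dice_coefficient(y_true, y_pred, smooth=1.0):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coefficient_loss(y_true, y_pred):
    return 1 - dice_coefficient(y_true, y_pred)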


def up_and_concate_3d(down_layer, layer):
    in_channel = down_layer.get_shape().as_list()[4]
    out_channel = in_channel // 2
    up = Conv3DTranspose(out_channel, [2, 2, 2], strides=[2, 2, 2], padding='valid')(down_layer)
    print("--------------")
    print(str(up.get_shape()))

    print(str(layer.get_shape()))
    print("--------------")
    # The backend concat must be wrapped in a Lambda layer: calling
    # K.concatenate directly returns a raw tensor that Keras cannot track
    # when building the Model graph.
    my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=4))
    concate = my_concat([up, layer])
    return concate


def attention_block_3d(x, g, inter_channel):
    '''
    Attention gate for a 3D U-Net skip connection.

    :param x: skip connection from the down-sampling path, shape (?, x_height, x_width, x_depth, x_channel)
    :param g: gating signal from the coarser up-sampling path, shape (?, g_height, g_width, g_depth, g_channel),
              where g_height, g_width, g_depth = x_height/2, x_width/2, x_depth/2
    :return: x re-weighted by the attention map, shape (?, x_height, x_width, x_depth, x_channel)
    '''
    # theta_x(?,g_height,g_width,g_depth,inter_channel)
    theta_x = Conv3D(inter_channel, [2, 2, 2], strides=[2, 2, 2])(x)

    # phi_g(?,g_height,g_width,g_depth,inter_channel)
    phi_g = Conv3D(inter_channel, [1, 1, 1], strides=[1, 1, 1])(g)

    # f(?,g_height,g_width,g_depth,inter_channel)
    f = Activation('relu')(keras.layers.add([theta_x, phi_g]))

    # psi_f(?,g_height,g_width,g_depth,1)
    psi_f = Conv3D(1, [1, 1, 1], strides=[1, 1, 1])(f)

    # sigm_psi_f(?, g_height, g_width, g_depth, 1)
    sigm_psi_f = Activation('sigmoid')(psi_f)

    # rate(?, x_height, x_width, x_depth, 1), broadcast over x's channels
    rate = UpSampling3D(size=[2, 2, 2])(sigm_psi_f)

    # att_x(?,x_height,x_width,x_depth,x_channel)
    att_x = keras.layers.multiply([x, rate])

    return att_x
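Note that attention_block_3d is defined but never called in unet_model_3d below. A sketch of how it could gate each skip connection (turning this into an attention U-Net variant; this is my wiring, not what the post's model does) would be:

def up_and_concate_attention_3d(down_layer, layer):
    # Hypothetical variant of up_and_concate_3d: attend over the skip
    # connection (layer) using the coarser feature map (down_layer) as the gate.
    in_channel = down_layer.get_shape().as_list()[4]
    out_channel = in_channel // 2
    att_layer = attention_block_3d(layer, down_layer, out_channel)
    up = Conv3DTranspose(out_channel, [2, 2, 2], strides=[2, 2, 2], padding='valid')(down_layer)
    my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=4))
    return my_concat([up, att_layer])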


def unet_model_3d(input_shape, n_labels, batch_normalization=False, initial_learning_rate=0.00001,
                  metrics=m.dice_coefficient):
    """
    input_shape: without batch_size, (img_height, img_width, img_depth, n_channels)
    metrics: a Keras metric function, or a list of them
    """

    inputs = Input(input_shape)

    down_layer = []

    layer = inputs

    # down_layer_1
    layer = res_block_v2_3d(layer, 64, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)

    print(str(layer.get_shape()))

    # down_layer_2
    layer = res_block_v2_3d(layer, 128, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)

    print(str(layer.get_shape()))

    # down_layer_3
    layer = res_block_v2_3d(layer, 256, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)

    print(str(layer.get_shape()))

    # down_layer_4
    layer = res_block_v2_3d(layer, 512, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)

    print(str(layer.get_shape()))

    # bottle_layer
    layer = res_block_v2_3d(layer, 1024, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_4
    layer = up_and_concate_3d(layer, down_layer[3])
    layer = res_block_v2_3d(layer, 512, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_3
    layer = up_and_concate_3d(layer, down_layer[2])
    layer = res_block_v2_3d(layer, 256, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_2
    layer = up_and_concate_3d(layer, down_layer[1])
    layer = res_block_v2_3d(layer, 128, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_1
    layer = up_and_concate_3d(layer, down_layer[0])
    layer = res_block_v2_3d(layer, 64, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # score_layer
    layer = Conv3D(n_labels, [1, 1, 1], strides=[1, 1, 1])(layer)
    print(str(layer.get_shape()))

    # softmax
    layer = Activation('softmax')(layer)
    print(str(layer.get_shape()))

    outputs = layer

    model = Model(inputs=inputs, outputs=outputs)

    metrics = [metrics]

    model.compile(optimizer=Adam(lr=initial_learning_rate), loss=m.dice_coefficient_loss, metrics=metrics)

    return model


def res_block_v2_3d(input_layer, out_n_filters, batch_normalization=False, kernel_size=[3, 3, 3], stride=[1, 1, 1],
                    padding='same'):
    input_n_filters = input_layer.get_shape().as_list()[4]  # channels_last: the channel axis of a 5D tensor is index 4 (index 3 is depth)
    print(str(input_layer.get_shape()))
    layer = input_layer

    for i in range(2):
        if batch_normalization:
            layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)
        layer = Conv3D(out_n_filters, kernel_size, strides=stride, padding=padding)(layer)

    if out_n_filters != input_n_filters:
        skip_layer = Conv3D(out_n_filters, [1, 1, 1], strides=stride, padding=padding)(input_layer)
    else:
        skip_layer = input_layer

    out_layer = keras.layers.add([layer, skip_layer])

    return out_layer
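Putting the reader and the model together, a minimal training sketch (my wiring, not shown in the post; it assumes the name lists exist and the depth-padding fix above has been applied inside the reader) could be:

if __name__ == '__main__':
    reader = brast_reader(train_batch_size=1, val_batch_size=1, test_batch_size=1)
    model = unet_model_3d((reader.img_height, reader.img_width, reader.img_depth, 1),
                          n_labels=reader.n_labels, batch_normalization=True)
    for epoch in range(10):
        for step in range(reader.n_train_steps_per_epoch):
            imgs, labels = reader.next_train_batch_3d()
            loss = model.train_on_batch(imgs, labels)
        val_imgs, val_labels = reader.next_val_batch_3d()
        print(model.test_on_batch(val_imgs, val_labels))
    # Note: a full 240x240x160 volume with 64 first-layer filters is extremely
    # memory-hungry; in practice patches or a smaller base width may be needed.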

The network is essentially the same as the U-Net 2D used for VOC segmentation in the earlier post in this series; the convolution, pooling, and concatenation operations are simply swapped for their 3D counterparts.


Some bugs surfaced while training this; the latest code is on GitHub, so treat the repository as authoritative:

https://github.com/panxiaobai/brats_keras
