Medical Image Segmentation with U-Net and FCN

A record of my own implementation process. My graduation project involves medical image segmentation, and after reviewing the literature I decided to start from two segmentation networks, U-Net and FCN. For the details of the two architectures, please refer to the original papers:
"Fully Convolutional Networks for Semantic Segmentation"
"U-Net: Convolutional Networks for Biomedical Image Segmentation"
What follows mainly documents how I completed the work step by step.

I. Preparation

1. Environment:
This post uses PyCharm + Keras.
Main packages used: keras, matplotlib, re, pydicom (plus opencv-python for the cv2 module used below).
For detailed installation instructions, search online; one possible pip setup is sketched below.
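For reference, a minimal install command covering these dependencies (assuming pip; tensorflow is the Keras backend used throughout this post, opencv-python provides cv2, and re ships with the standard library):

pip install tensorflow keras matplotlib pydicom opencv-python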
2. Dataset: the open cardiac-segmentation-master dataset (Sunnybrook cardiac MR data) is used; the following code converts it into PNG images that can be loaded into the networks.
dataconvert.py

#!/usr/bin/env python2.7
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import pydicom, cv2, re
import fnmatch, sys
import numpy as np
from helpers import center_crop, get_SAX_SERIES

seed = 1234
np.random.seed(seed)

SAX_SERIES = get_SAX_SERIES()
SUNNYBROOK_ROOT_PATH = '.\\Sunnybrook_data'

TRAIN_CONTOUR_PATH = os.path.join(SUNNYBROOK_ROOT_PATH,
                                  'Sunnybrook Cardiac MR Database ContoursPart3',
                                  'TrainingDataContours')
TRAIN_IMG_PATH = os.path.join(SUNNYBROOK_ROOT_PATH,
                              'challenge_training')


def shrink_case(case):
    toks = case.split('-')

    def shrink_if_number(x):
        try:
            cvt = int(x)
            return str(cvt)
        except ValueError:
            return x

    return '-'.join([shrink_if_number(t) for t in toks])


class Contour(object):
    def __init__(self, ctr_path):
        ctr_path = ctr_path.replace("\\", "/")
        self.ctr_path = ctr_path
        match = re.search(r'/([^/]*)/contours-manual/IRCCI-expert/IM-0001-(\d{4})-.*', ctr_path)
        self.case = shrink_case(match.group(1))
        self.img_no = int(match.group(2))

    def __str__(self):
        return '<Contour for case %s, image %d>' % (self.case, self.img_no)

    __repr__ = __str__


def read_contour(contour, data_path):
    filename = 'IM-%s-%04d.dcm' % (SAX_SERIES[contour.case], contour.img_no)
    full_path = os.path.join(data_path, contour.case, filename)
    f = pydicom.read_file(full_path)
    img = f.pixel_array.astype('int')
    mask = np.zeros_like(img, dtype='uint8')
    coords = np.loadtxt(contour.ctr_path, delimiter=' ').astype('int')
    cv2.fillPoly(mask, [coords], 1)
    if img.ndim < 3:
        img = img[..., np.newaxis]
        mask = mask[..., np.newaxis]

    return img, mask


def map_all_contours(contour_path, contour_type, shuffle=True):
    contours = [os.path.join(dirpath, f)
                for dirpath, dirnames, files in os.walk(contour_path)
                for f in fnmatch.filter(files, 'IM-0001-*-' + contour_type + 'contour-manual.txt')]
    if shuffle:
        print('Shuffling data')
        np.random.shuffle(contours)
    print('Number of examples: {:d}'.format(len(contours)))
    contours = [Contour(path) for path in contours]

    return contours


def export_all_contours(contours, data_path, crop_size):
    print('\nProcessing {:d} images and labels ...\n'.format(len(contours)))
    images = np.zeros((len(contours), crop_size, crop_size, 1), dtype=np.float32)
    masks = np.zeros((len(contours), crop_size, crop_size, 1), dtype=np.float32)
    for idx, contour in enumerate(contours):
        img, mask = read_contour(contour, data_path)
        img = center_crop(img, crop_size=crop_size)
        mask = center_crop(mask, crop_size=crop_size)
        images[idx] = img
        masks[idx] = mask

    return images, masks


if __name__ == '__main__':
    # if len(sys.argv) < 3:
    #    sys.exit('Usage: python %s  ' % sys.argv[0])
    contour_type = 'i'
    # os.environ['CUDA_VISIBLE_DEVICES'] = sys.argv[2]
    crop_size = 176

    # epsilon = 1e-6

    print('Mapping ground truth ' + contour_type + ' contours to images in train...')
    train_ctrs = map_all_contours(TRAIN_CONTOUR_PATH, contour_type, shuffle=True)
    print('Done mapping training set')



    print('\nBuilding Train innerdataset ...\n')
    img_train, mask_train = export_all_contours(train_ctrs,
                                                TRAIN_IMG_PATH,
                                                crop_size=crop_size)
    print(img_train.shape)
    
    for i in range(img_train.shape[0]):
        img = img_train[i, :, :, :]
        mask = mask_train[i, :, :, :] * 255
        # raw strings avoid backslash-escape problems in Windows paths; %04d zero-pads the index
        cv2.imwrite(r'D:\PycharmProjects\machine learning\LV seg\code\img\Im_image_%04d.png' % i, img)
        cv2.imwrite(r'D:\PycharmProjects\machine learning\LV seg\code\mask\Im_mask_%04d.png' % i, mask)

This script reads the DICOM images in the dataset that have left-ventricle endocardium or epicardium contours, converts the corresponding contour .txt files into mask images, and saves both to the specified paths. Endocardium vs. epicardium is selected via contour_type: 'i' for endocardium, 'o' for epicardium.
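If both contour types are needed, a minimal sketch reusing map_all_contours and export_all_contours from dataconvert.py above (output handling omitted):

for contour_type in ('i', 'o'):  # 'i' = endocardium, 'o' = epicardium
    ctrs = map_all_contours(TRAIN_CONTOUR_PATH, contour_type, shuffle=False)
    imgs, masks = export_all_contours(ctrs, TRAIN_IMG_PATH, crop_size=176)
    print(contour_type, imgs.shape, masks.shape)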
The get_SAX_SERIES() used in the code lives in helpers.py:

#!/usr/bin/env python2.7

import numpy as np
import cv2
from keras import backend as K


def get_SAX_SERIES():
    SAX_SERIES = {}
    with open('SAX_series.txt', 'r') as f:
        for line in f:
            if not line.startswith('#'):
                key, val = line.split(':')
                SAX_SERIES[key.strip()] = val.strip()

    return SAX_SERIES


def mvn(ndarray):
    '''Input ndarray is of rank 3 (height, width, depth).

    MVN performs per channel mean-variance normalization.
    '''
    epsilon = 1e-6
    mean = ndarray.mean(axis=(0,1), keepdims=True)
    std = ndarray.std(axis=(0,1), keepdims=True)

    return (ndarray - mean) / (std + epsilon)


def reshape(ndarray, to_shape):
    '''Reshapes a center cropped (or padded) array back to its original shape.'''
    h_in, w_in, d_in = ndarray.shape
    h_out, w_out, d_out = to_shape
    if h_in > h_out: # center crop along h dimension
        h_offset = (h_in - h_out) // 2  # floor division so indices stay integers under Python 3
        ndarray = ndarray[h_offset:(h_offset+h_out), :, :]
    else: # zero pad along h dimension
        pad_h = (h_out - h_in)
        rem = pad_h % 2
        pad_dim_h = (pad_h//2, pad_h//2 + rem)
        # npad is tuple of (n_before, n_after) for each (h,w,d) dimension
        npad = (pad_dim_h, (0,0), (0,0))
        ndarray = np.pad(ndarray, npad, 'constant', constant_values=0)
    if w_in > w_out: # center crop along w dimension
        w_offset = (w_in - w_out) // 2
        ndarray = ndarray[:, w_offset:(w_offset+w_out), :]
    else: # zero pad along w dimension
        pad_w = (w_out - w_in)
        rem = pad_w % 2
        pad_dim_w = (pad_w//2, pad_w//2 + rem)
        npad = ((0,0), pad_dim_w, (0,0))
        ndarray = np.pad(ndarray, npad, 'constant', constant_values=0)
    
    return ndarray # reshaped


def center_crop(ndarray, crop_size):
    '''Input ndarray is of rank 3 (height, width, depth).

    Argument crop_size is an integer for square cropping only.

    Performs padding and center cropping to a specified size.
    '''
    h, w, d = ndarray.shape
    if crop_size == 0:
        raise ValueError('argument crop_size must be non-zero integer')
    
    if any([dim < crop_size for dim in (h, w)]):
        # zero pad along each (h, w) dimension before center cropping
        pad_h = (crop_size - h) if (h < crop_size) else 0
        pad_w = (crop_size - w) if (w < crop_size) else 0
        rem_h = pad_h % 2
        rem_w = pad_w % 2
        pad_dim_h = (pad_h//2, pad_h//2 + rem_h)  # floor division keeps pad widths integral
        pad_dim_w = (pad_w//2, pad_w//2 + rem_w)
        # npad is tuple of (n_before, n_after) for each (h,w,d) dimension
        npad = (pad_dim_h, pad_dim_w, (0,0))
        ndarray = np.pad(ndarray, npad, 'constant', constant_values=0)
        h, w, d = ndarray.shape
    # center crop
    h_offset = int((h - crop_size) / 2)
    w_offset = int((w - crop_size) / 2)
    cropped = ndarray[h_offset:(h_offset+crop_size),
                      w_offset:(w_offset+crop_size), :]

    return cropped


def lr_poly_decay(model, base_lr, curr_iter, max_iter, power=0.5):
    lrate = base_lr * (1.0 - (curr_iter / float(max_iter)))**power
    K.set_value(model.optimizer.lr, lrate)

    return K.eval(model.optimizer.lr)


def dice_coef(y_true, y_pred):
    intersection = np.sum(y_true * y_pred, axis=None)
    summation = np.sum(y_true, axis=None) + np.sum(y_pred, axis=None)
    
    return 2.0 * intersection / summation


def jaccard_coef(y_true, y_pred):
    intersection = np.sum(y_true * y_pred, axis=None)
    union = np.sum(y_true, axis=None) + np.sum(y_pred, axis=None) - intersection

    return float(intersection) / float(union)



This file collects a few helper functions. The contents of SAX_series.txt are as follows:

# challenge training
SC-HF-I-1: 0004
SC-HF-I-2: 0106
SC-HF-I-4: 0116
SC-HF-I-40: 0134
SC-HF-NI-3: 0379
SC-HF-NI-4: 0501
SC-HF-NI-34: 0446
SC-HF-NI-36: 0474
SC-HYP-1: 0550
SC-HYP-3: 0650
SC-HYP-38: 0734
SC-HYP-40: 0755
SC-N-2: 0898
SC-N-3: 0915
SC-N-40: 0944
# challenge online
SC-HF-I-9: 0241
SC-HF-I-10: 0024
SC-HF-I-11: 0043
SC-HF-I-12: 0062
SC-HF-NI-12: 0286
SC-HF-NI-13: 0304
SC-HF-NI-14: 0331
SC-HF-NI-15: 0359
SC-HYP-9: 0003
SC-HYP-10: 0579
SC-HYP-11: 0601
SC-HYP-12: 0629
SC-N-9: 1031
SC-N-10: 0851
SC-N-11: 0878
# challenge validation
SC-HF-I-5: 0156
SC-HF-I-6: 0180
SC-HF-I-7: 0209
SC-HF-I-8: 0226
SC-HF-NI-7: 0523
SC-HF-NI-11: 0270
SC-HF-NI-31: 0401
SC-HF-NI-33: 0424
SC-HYP-6: 0767
SC-HYP-7: 0007
SC-HYP-8: 0796
SC-HYP-37: 0702
SC-N-5: 0963
SC-N-6: 0984
SC-N-7: 1009

This part covered how to convert the dataset into PNG images.
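As a quick sanity check of the helpers above, center_crop and dice_coef can be exercised on toy arrays (a sketch, assuming helpers.py is on the path):

import numpy as np
from helpers import center_crop, dice_coef

# center_crop crops an 8x8 single-channel array down to 4x4
# (inputs smaller than crop_size are zero-padded first)
arr = np.arange(64, dtype=np.float32).reshape(8, 8, 1)
print(center_crop(arr, crop_size=4).shape)  # (4, 4, 1)

# dice = 2*intersection/(|A|+|B|) = 2*1/(2+1) = 0.666...
y_true = np.array([1, 1, 0, 0], dtype=np.float32)
y_pred = np.array([1, 0, 0, 0], dtype=np.float32)
print(dice_coef(y_true, y_pred))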

II. Building the FCN and U-Net models

The model definitions live in model.py:

#!/usr/bin/env python2.7
'''
Defines the two models: FCN and U-Net.
The FCN is compiled with dice_coef_loss; the U-Net below is compiled with
binary_crossentropy. Metrics: accuracy, dice_coef, jaccard_coef.
'''
from keras import optimizers
from keras.models import Model
from keras.layers import Dropout, Lambda,concatenate
from keras.layers import Input, average
from keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, BatchNormalization,LeakyReLU
from keras.layers import ZeroPadding2D, Cropping2D, UpSampling2D
from keras import backend as K
import tensorflow as tf
# A method used in the FCN model to help prevent overfitting; see the paper and its reference code for details
def mvn(tensor):
    '''Performs per-channel spatial mean-variance normalization.'''
    epsilon = 1e-6
    mean = K.mean(tensor, axis=(1, 2), keepdims=True)
    std = K.std(tensor, axis=(1, 2), keepdims=True)
    mvn = (tensor - mean) / (std + epsilon)

    return mvn


def crop(tensors):
    '''
        List of 2 tensors, the second tensor having larger spatial dimensions.
     '''
    from keras.layers import Cropping2D
    h_dims, w_dims = [], []
    for t in tensors:
        b, h, w, d = K.get_variable_shape(t)
        h_dims.append(h)
        w_dims.append(w)
    crop_h, crop_w = (h_dims[1] - h_dims[0]), (w_dims[1] - w_dims[0])
    rem_h = crop_h % 2
    rem_w = crop_w % 2
    crop_h_dims = (int(crop_h / 2), int(crop_h / 2 + rem_h))
    crop_w_dims = (int(crop_w / 2), int(crop_w / 2 + rem_w))
    cropped = Cropping2D(cropping=(crop_h_dims, crop_w_dims))(tensors[1])

    return cropped


def dice_coef(y_true, y_pred, smooth=0.0):
    '''Average dice coefficient per batch.'''
    axes = (1, 2, 3)
    intersection = K.sum(y_true * y_pred, axis=axes)
    summation = K.sum(y_true, axis=axes) + K.sum(y_pred, axis=axes)

    return K.mean((2.0 * intersection + smooth) / (summation + smooth), axis=0)


def dice_coef_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred, smooth=10.0)


def jaccard_coef(y_true, y_pred, smooth=0.0):
    '''Average jaccard coefficient per batch.'''
    axes = (1, 2, 3)
    intersection = K.sum(y_true * y_pred, axis=axes)
    union = K.sum(y_true, axis=axes) + K.sum(y_pred, axis=axes) - intersection

    return K.mean((intersection + smooth) / (union + smooth), axis=0)
# Convenience wrapper for the Conv2D + BatchNormalization block used throughout the U-Net
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same'):
    x = Conv2D(nb_filter,  kernel_size, strides=strides, padding=padding, activation='relu')(x)
    x = BatchNormalization(axis=3)(x)
   # x = LeakyReLU(alpha=0.1)(x)
    return x
# U-Net architecture
def unet_model(size):

    inpt = Input(shape=size)
    conv1 = Conv2d_BN(inpt, 64, (3, 3))
    conv1 = Conv2d_BN(conv1, 64, (3, 3))
    pool1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv1)

    conv2 = Conv2d_BN(pool1, 128, (3, 3))
    conv2 = Conv2d_BN(conv2, 128, (3, 3))
    pool2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv2)

    conv3 = Conv2d_BN(pool2, 256, (3, 3))
    conv3 = Conv2d_BN(conv3, 256, (3, 3))
    pool3 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv3)

    conv4 = Conv2d_BN(pool3, 512, (3, 3))
    conv4 = Conv2d_BN(conv4, 512, (3, 3))
    pool4 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv4)

    conv5 = Conv2d_BN(pool4, 1024, (3, 3))
    conv5 = Conv2d_BN(conv5, 1024, (3, 3))
    conv5 = Dropout(0.5)(conv5)

    # convt1 = Conv2dT_BN(conv5, 512, (3, 3))
    # convt1 = Conv2dT_BN(conv5,size=(2, 2))
    up1 = UpSampling2D(size=(2, 2))(conv5)
    convt1 = Conv2d_BN(up1, 512, (2, 2))
    concat1 = concatenate([conv4, convt1], axis=3)
    # concat1 = Dropout(0.5)(concat1)
    conv6 = Conv2d_BN(concat1, 512, (3, 3))
    conv6 = Conv2d_BN(conv6, 512, (3, 3))
    conv6 = Dropout(0.5)(conv6)

    # convt2 = Conv2dT_BN(conv6, 256, (3, 3))
    up2 = UpSampling2D(size=(2, 2))(conv6)
    convt2 = Conv2d_BN(up2, 256, (2, 2))
    concat2 = concatenate([conv3, convt2], axis=3)
    # concat2 = Dropout(0.5)(concat2)
    conv7 = Conv2d_BN(concat2, 256, (3, 3))
    conv7 = Conv2d_BN(conv7, 256, (3, 3))
    conv7 = Dropout(0.5)(conv7)

    # convt3 = Conv2dT_BN(conv7, 128, (3, 3))
    up3 = UpSampling2D(size=(2, 2))(conv7)
    convt3 = Conv2d_BN(up3, 128, (2, 2))
    concat3 = concatenate([conv2, convt3], axis=3)
    # concat3 = Dropout(0.5)(concat3)
    conv8 = Conv2d_BN(concat3, 128, (3, 3))
    conv8 = Conv2d_BN(conv8, 128, (3, 3))
    conv8 = Dropout(0.5)(conv8)

    # convt4 = Conv2dT_BN(conv8, 64, (3, 3))
    up4 = UpSampling2D(size=(2, 2))(conv8)
    convt4 = Conv2d_BN(up4, 64, (2, 2))
    concat4 = concatenate([conv1, convt4], axis=3)
    # concat4 = Dropout(0.5)(concat4)
    conv9 = Conv2d_BN(concat4, 64, (3, 3))
    conv9 = Conv2d_BN(conv9, 64, (3, 3))

    outpt = Conv2D(filters=1, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='sigmoid')(conv9)
    model = Model(inputs=inpt, outputs=outpt)

   # sgd = optimizers.SGD(lr=0.01, momentum=0.9, nesterov=True)
    model.compile(optimizer='Adagrad', loss='binary_crossentropy',
                  metrics=['accuracy', dice_coef, jaccard_coef])  # 'Adam' is another option

    return model




def fcn_model(input_shape, num_classes, weights=None):
    ''' "Skip" FCN architecture similar to Long et al., 2015
    https://arxiv.org/abs/1411.4038
    '''



    if num_classes == 2:
        num_classes = 1
        loss = dice_coef_loss
        activation = 'sigmoid'
    else:
        loss = 'categorical_crossentropy'
        activation = 'softmax'

    kwargs = dict(
        kernel_size=3,
        strides=1,
        activation='relu',
        padding='same',
        use_bias=True,
        kernel_initializer='glorot_uniform',
        bias_initializer='zeros',
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
    )

    data = Input(shape=input_shape)
    dataB= BatchNormalization(axis=3)(data)
    pad = ZeroPadding2D(padding=5, name='pad')(dataB)

    conv1 = Conv2D(filters=64, name='conv1', **kwargs)(pad)
    conv1B = BatchNormalization(axis=3)(conv1)

    conv2 = Conv2D(filters=64, name='conv2', **kwargs)(conv1B)
    conv2B = BatchNormalization(axis=3)(conv2)

    conv3 = Conv2D(filters=64, name='conv3', **kwargs)(conv2B)
    conv3B = BatchNormalization(axis=3)(conv3)

    pool1 = MaxPooling2D(pool_size=3, strides=2,
                         padding='valid', name='pool1')(conv3B)

    conv4 = Conv2D(filters=128, name='conv4', **kwargs)(pool1)
    conv4B = BatchNormalization(axis=3)(conv4)

    conv5 = Conv2D(filters=128, name='conv5', **kwargs)(conv4B)
    conv5B = BatchNormalization(axis=3)(conv5)

    conv6 = Conv2D(filters=128, name='conv6', **kwargs)(conv5B)
    conv6B = BatchNormalization(axis=3)(conv6)

    conv7 = Conv2D(filters=128, name='conv7', **kwargs)(conv6B)
    conv7B = BatchNormalization(axis=3)(conv7)
    pool2 = MaxPooling2D(pool_size=3, strides=2,
                         padding='valid', name='pool2')(conv7B)

    conv8 = Conv2D(filters=256, name='conv8', **kwargs)(pool2)
    conv8B = BatchNormalization(axis=3)(conv8)

    conv9 = Conv2D(filters=256, name='conv9', **kwargs)(conv8B)
    conv9B = BatchNormalization(axis=3)(conv9)

    conv10 = Conv2D(filters=256, name='conv10', **kwargs)(conv9B)
    conv10B = BatchNormalization(axis=3)(conv10)

    conv11 = Conv2D(filters=256, name='conv11', **kwargs)(conv10B)
    conv11B = BatchNormalization(axis=3)(conv11)
    pool3 = MaxPooling2D(pool_size=3, strides=2,
                         padding='valid', name='pool3')(conv11B)
    drop1 = Dropout(rate=0.5, name='drop1')(pool3)

    conv12 = Conv2D(filters=512, name='conv12', **kwargs)(drop1)
    conv12B = BatchNormalization(axis=3)(conv12)

    conv13 = Conv2D(filters=512, name='conv13', **kwargs)(conv12B)
    conv13B = BatchNormalization(axis=3)(conv13)

    conv14 = Conv2D(filters=512, name='conv14', **kwargs)(conv13B)
    conv14B = BatchNormalization(axis=3)(conv14)

    conv15 = Conv2D(filters=512, name='conv15', **kwargs)(conv14B)
    conv15B = BatchNormalization(axis=3)(conv15)
    drop2 = Dropout(rate=0.5, name='drop2')(conv15B)

    score_conv15 = Conv2D(filters=num_classes, kernel_size=1,
                          strides=1, activation=None, padding='valid',
                          kernel_initializer='glorot_uniform', use_bias=True,
                          name='score_conv15')(drop2)
    upsample1 = Conv2DTranspose(filters=num_classes, kernel_size=3,
                                strides=2, activation=None, padding='valid',
                                kernel_initializer='glorot_uniform', use_bias=False,
                                name='upsample1')(score_conv15)
    score_conv11 = Conv2D(filters=num_classes, kernel_size=1,
                          strides=1, activation=None, padding='valid',
                          kernel_initializer='glorot_uniform', use_bias=True,
                          name='score_conv11')(conv11B)
    crop1 = Lambda(crop, name='crop1')([upsample1, score_conv11])
    fuse_scores1 = average([crop1, upsample1], name='fuse_scores1')

    upsample2 = Conv2DTranspose(filters=num_classes, kernel_size=3,
                                strides=2, activation=None, padding='valid',
                                kernel_initializer='glorot_uniform', use_bias=False,
                                name='upsample2')(fuse_scores1)
    score_conv7 = Conv2D(filters=num_classes, kernel_size=1,
                         strides=1, activation=None, padding='valid',
                         kernel_initializer='glorot_uniform', use_bias=True,
                         name='score_conv7')(conv7B)
    crop2 = Lambda(crop, name='crop2')([upsample2, score_conv7])
    fuse_scores2 = average([crop2, upsample2], name='fuse_scores2')

    upsample3 = Conv2DTranspose(filters=num_classes, kernel_size=3,
                                strides=2, activation=None, padding='valid',
                                kernel_initializer='glorot_uniform', use_bias=False,
                                name='upsample3')(fuse_scores2)
    crop3 = Lambda(crop, name='crop3')([data, upsample3])
    predictions = Conv2D(filters=num_classes, kernel_size=1,
                         strides=1, activation=activation, padding='valid',
                         kernel_initializer='glorot_uniform', use_bias=True,
                         name='predictions')(crop3)

    model = Model(inputs=data, outputs=predictions)
    if weights is not None:
        model.load_weights(weights)
    sgd = optimizers.SGD(lr=0.01, momentum=0.9, nesterov=True)
    model.compile(optimizer=sgd, loss=loss,
                  metrics=['accuracy', dice_coef, jaccard_coef])

    return model

Dropout is used in the U-Net model to guard against overfitting. See the code for the exact loss and optimizer settings; both can be changed and tuned freely.
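For reference, either network can be instantiated and inspected like this (a sketch; the input shape matches the 176x176 grayscale crops produced earlier):

from model import unet_model, fcn_model

model = unet_model((176, 176, 1))      # U-Net: Adagrad + binary_crossentropy, sigmoid output
# model = fcn_model((176, 176, 1), 2)  # FCN: num_classes=2 collapses to one sigmoid map with dice loss
model.summary()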

III. Loading the data and training

This part loads the training images into the network and trains it; the file is main.py:

'''
filename: main.py
description: reads the images from disk, splits them into training and test
sets, builds either the FCN or the U-Net model, and trains it.
Training here uses the TensorFlow CPU backend; to train on a GPU, consult the
relevant documentation and adjust parts of the code. One GPU walkthrough:
https://blog.csdn.net/zong596568821xp/article/details/86494916
The model is saved once training completes; training logs are written under
keras_log and can be viewed with TensorBoard.
'''


import numpy as np
import os
import tensorflow as tf
from keras.models import save_model, load_model
from keras import callbacks
from model import fcn_model, unet_model
from model import dice_coef_loss, dice_coef, jaccard_coef
import matplotlib.pyplot as plt
import cv2

np.random.seed(678)
tf.set_random_seed(5678)

# --- get data ---
# Dataset location. Everything is read straight into memory; for large datasets
# this is not recommended (consider a tfrecord-based input pipeline instead).
data_location = r"D:\PycharmProjects\machine learning\LV seg\innerdataset\img_pro"
train_data = []  # create an empty list
for dirName, subdirList, fileList in os.walk(data_location):  # walk all files under the path; fileList is the list of file names, see os.walk
    for filename in fileList:
        train_data.append(os.path.join(dirName, filename))  # store each image's full path in the list


data_location = "D:\PycharmProjects\machine learning\LV seg\innerdataset\mask_pro"
train_data_gt = []  # create an empty list
for dirName, subdirList, fileList in os.walk(data_location):
    for filename in fileList:
            train_data_gt.append(os.path.join(dirName, filename))

train_images = np.zeros(shape=(250, 176, 176, 1))  # holds the training images
train_labels = np.zeros(shape=(250, 176, 176, 1))  # holds the training masks

'''
Walk the file lists and read in the images.
'''
for file_index in range(len(train_data)):
    # note: np.resize tiles or truncates rather than interpolating; the PNGs saved by
    # dataconvert.py are already 176x176, so this is effectively a no-op here
    train_images[file_index, :, :] = np.expand_dims(np.resize(cv2.imread(train_data[file_index], 0), new_shape=(176, 176)), axis=2)
    train_labels[file_index, :, :] = np.expand_dims(np.resize(cv2.imread(train_data_gt[file_index], 0), new_shape=(176, 176)), axis=2)  # 0 = cv2.IMREAD_GRAYSCALE
'''
Normalize the pixel values of the loaded images to [0, 1]:
(f - fmin) / (fmax - fmin)
'''
images = (train_images - train_images.min()) / (train_images.max() - train_images.min())
labels = (train_labels - train_labels.min()) / (train_labels.max() - train_labels.min())
'''
Split the processed data into a training set and a test set; a validation set can
be added if necessary (fit() can also split one off the training set directly).
'''
print("creating train images :\n")
trainX = images[0:250]  # training set
trainY = labels[0:250]
print("creating train images done total:{}\n".format(trainY.shape[0]))
'''
print("creating validtion images :\n")
validtionX = images[120:135]  # validation set
validtionY = labels[120:135]
print("creating validtion images done total:{}\n".format(validtionX.shape[0]))
'''
'''
print("creating test images :\n")
predictionX = images[132:135]  # test set, used for prediction
predictionY = labels[132:135]
print("creating test images done \n")
'''

'''
Helper class implementing next_batch(); no need to study it in detail.
'''
class DataSet(object):

    def __init__(self, images, labels, num_examples):
        self._images = images
        self._labels = labels
        self._epochs_completed = 0  # number of completed passes over the data
        self._index_in_epochs = 0  # remembers where the previous next_batch() call stopped
        self._num_examples = num_examples  # number of training examples

    def next_batch(self, batch_size, fake_data=False, shuffle=True):
        start = self._index_in_epochs

        if self._epochs_completed == 0 and start == 0 and shuffle:
            index0 = np.arange(self._num_examples)

            np.random.shuffle(index0)

            self._images = np.array(self._images)[index0]
            self._labels = np.array(self._labels)[index0]


        if start + batch_size > self._num_examples:
            self._epochs_completed += 1
            rest_num_examples = self._num_examples - start
            images_rest_part = self._images[start:self._num_examples]
            labels_rest_part = self._labels[start:self._num_examples]
            if shuffle:
                index = np.arange(self._num_examples)
                np.random.shuffle(index)
                self._images = self._images[index]
                self._labels = self._labels[index]
            start = 0
            self._index_in_epochs = batch_size - rest_num_examples
            end = self._index_in_epochs
            images_new_part = self._images[start:end]
            labels_new_part = self._labels[start:end]
            return np.concatenate((images_rest_part, images_new_part), axis=0), np.concatenate(
                (labels_rest_part, labels_new_part), axis=0)

        else:
            self._index_in_epochs += batch_size
            end = self._index_in_epochs
            return self._images[start:end], self._labels[start:end]
# directory for the TensorBoard visualization logs
log_filepath = r"D:\PycharmProjects\machine learning\LV seg\code\logtest"  # log location
tb_cb = callbacks.TensorBoard(log_dir=log_filepath, write_images=1, histogram_freq=1)
# write the network weights to TensorBoard as images and record a histogram of
# each layer's outputs once per epoch
cbks = [tb_cb]  # callback list required by TensorBoard
inpt = (176, 176, 1)  # input image size for the network
model = unet_model(inpt)           # build the U-Net model
# model = fcn_model(inpt, 2, None)  # or build the FCN model instead

model.summary()  # print the model

# Dice on thresholded predictions; see the standard definition of the dice score for details
def dice_coef_theoretical(y_pred, y_true):
    """Define the dice coefficient
        Args:
        y_pred: Prediction
        y_true: Ground truth Label
        Returns:
        Dice coefficient
        """

    y_true_f = tf.cast(tf.reshape(y_true, [-1]), tf.float32)  # flatten to one row and cast to float32
    y_pred_f = tf.cast(tf.greater(y_pred, 0.5), tf.float32)   # binarize the prediction at 0.5
    y_pred_f = tf.cast(tf.reshape(y_pred_f, [-1]), tf.float32)

    intersection = tf.reduce_sum(y_true_f * y_pred_f)  # intersection
    union = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f)  # sum of the two sets
    dice = (2.0 * intersection) / (union + 0.00001)  # dice score

    # define dice = 1 when both prediction and label are empty
    # (a plain Python `if` on graph tensors does not work, so use tf.where)
    both_empty = tf.logical_and(tf.equal(tf.reduce_sum(y_pred_f), 0.0),
                                tf.equal(tf.reduce_sum(y_true_f), 0.0))
    dice = tf.where(both_empty, tf.ones_like(dice), dice)

    return dice
if __name__ == '__main__':

    ds = DataSet(trainX, trainY, 100)  # wrap the data in a DataSet object first
    train_X, train_Y = ds.next_batch(100)  # take all 100 examples at once
    # see the Keras fit() documentation for the remaining parameters
    history = model.fit(train_X, train_Y, batch_size=4, epochs=50, verbose=1, callbacks=cbks, validation_split=0.1)

    save_model(model, 'test.h5')  # remember to change the name when training a different model
    '''
    # Load a trained model and run predictions on held-out images.
    # (For a cleaner split, evaluation was later moved into a separate utils file.)
    # The plots are not grayscale because OpenCV stores images as BGR while
    # matplotlib expects RGB; plt.imshow() is used here for convenience.
    model = load_model('fcn1.h5', custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef':dice_coef, 'jaccard_coef': jaccard_coef})
    ds1 = DataSet(validtionX, validtionY, 10)  # load the test data
    for i in range(10):  # predict one image at a time to inspect the results
        test_X, test_Y = ds1.next_batch(1)
        pred_Y = model.predict(test_X)
        acc = dice_coef_theoretical(pred_Y[0, :, :, 0], test_Y[0, :, :, 0])
        sess = tf.Session()
        print("Dice value{}".format(sess.run(acc)))
        print("\n")
        ii = 0
        plt.figure()
        plt.imshow(test_X[ii, :, :, 0])
        plt.title("tranX{}".format(ii))
        plt.axis('off')
        plt.figure()
        plt.imshow(test_Y[ii, :, :, 0])
        plt.title("tranY{}".format(ii))
        plt.axis('off')
        plt.figure()
        plt.imshow(pred_Y[ii, :, :, 0])
        plt.title("Predict{}".format(ii))
        plt.axis('off')
        plt.show()
    '''

With that, the model can segment the images.
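The training curves written by the TensorBoard callback can then be inspected from the command line, for example:

tensorboard --logdir "D:\PycharmProjects\machine learning\LV seg\code\logtest"

and opening the URL it prints in a browser.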
IV. Summary
This experiment ran in a CPU environment, which is slow: training for 50 epochs took about 8 hours. In addition, no tuning of the loss, optimizer, or other settings was attempted, so this post is only a reference for getting image segmentation working. For the models and algorithms themselves, please refer to the original papers.
