用keras实现fcn全卷积网络的图像分割

一开始对图像分割,全卷积网络很懵逼,根本不知道说的是什么,怎样实现,直到发现了这篇博文https://blog.csdn.net/u012931582/article/details/70215756/
在此感谢博客上大佬精彩的文章

首先谈一下FCN,FCN的全称是Fully Convolutional Networks,首次出现在Jonathan Long大神14年发表的经典论文“Fully Convolutional Networks for Semantic Segmentation”中 https://arxiv.org/abs/1411.4038 。文中首次提出了FCN全卷积结构,把传统分类网络的全连接替换为了卷积层,运用upsampling和 concatenate操作,对图像进行了像素级的分类,完成了语义分割的任务。
Unet http://www.arxiv.org/pdf/1505.04597.pdf 是15年提出的,结构上主要借鉴了‘Fully Convolutional Networks for Semantic Segmentation’的思想,在先前的fcn网络里进行了改动,满足了医学图像领域中样本尺度大,对网络性能需求高的要求。
个人觉得fcn这种结构已经和cnn一样是一类结构而不是一个特定的网络了,因此在这篇记录里面没有详细说明区别,造成误解向大家道歉。关于伪代码中损失函数问题,代码可能是我在测试前景背景分割时复制下来的,所以损失函数用了二分类,如果要改为多分类,可以改为softmax激活函数和对应的损失函数。

图像分割源于医疗领域的需求,现阶段图像分割还主要是应用在医疗影像处理方面。什么是图像分割呢?查阅论文之后,用我的理解简单地概括一下:在处理图像分类问题时,我们的输出层总是由全连接层(FC)组成,这就直接导致了输出只能作为整张图像的一个标签来使用,不能对应到图像上的具体位置。早些年的研究为了使预测结果能直观地反映在原图上,就用每一个像素点或者几个邻近的像素块作为一个标签来计算,这样做的直接结果是训练过程中产生了大量的计算,并且这些像素块由于相互邻近,很大一部分计算都是重复的(也导致了类别和定位不能同时精确的问题)。为了解决这个问题,研究者提出了FCN结构,它用最后几层上卷积的map来提升分类的精度,用最初几步下卷积的map提升定位的精度。

用keras实现fcn全卷积网络的图像分割_第1张图片
并在训练的时候以预处理的灰度图作为标签(mask)进行训练
这是我在GitHub找到的一个代码加以修改

from __future__ import print_function
from keras.preprocessing.image import ImageDataGenerator
import numpy as np 
import os
import glob
import skimage.io as io
import skimage.transform as trans
import cv2
##这一步获取了原图和mask
def trainGenerator(batch_size,train_path,train_path2,image_folder,mask_folder,aug_dict,
                    flag_multi_class, num_class ,image_color_mode = "grayscale",
                    mask_color_mode = "grayscale",image_save_prefix  = "image",mask_save_prefix  = "mask",save_to_dir = None,target_size = (256,256),seed = 1):
    """Yield ``(image_batch, mask_batch)`` pairs read from two directory trees.

    Images are read from ``train_path``/``image_folder`` and masks from
    ``train_path2``/``mask_folder``. Both streams use an ``ImageDataGenerator``
    built from the same ``aug_dict`` and the same ``seed``; the intent is that
    an image and its mask receive matching augmentations.

    Every raw pair is passed through ``adjustData`` (together with
    ``flag_multi_class`` and ``num_class``) before being yielded, so the
    result can be handed straight to ``model.fit_generator``.

    Set ``save_to_dir`` to a directory to have Keras write the augmented
    samples to disk, using the two ``*_save_prefix`` values as file prefixes.
    """
    # Options that are identical for the image stream and the mask stream.
    shared_flow_args = dict(
        class_mode = None,
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_to_dir,
        seed = seed)

    image_batches = ImageDataGenerator(**aug_dict).flow_from_directory(
        train_path,
        classes = [image_folder],
        color_mode = image_color_mode,
        save_prefix = image_save_prefix,
        **shared_flow_args)
    mask_batches = ImageDataGenerator(**aug_dict).flow_from_directory(
        train_path2,
        classes = [mask_folder],
        color_mode = mask_color_mode,
        save_prefix = mask_save_prefix,
        **shared_flow_args)

    for image_batch, mask_batch in zip(image_batches, mask_batches):
        # adjustData returns the normalised (img, mask) tuple that is yielded as-is.
        yield adjustData(image_batch, mask_batch, flag_multi_class, num_class)

def adjustData(img,mask,flag_multi_class,num_class):
    """Normalise an image/mask pair for training.

    Multi-class mode (``flag_multi_class`` truthy):
        ``img`` is divided by 255. The first channel of ``mask`` is expanded
        into ``len(num_class)`` channels, where ``num_class`` is a sequence of
        mask pixel values, one per class. The encoding is *inverted*
        one-hot: channel ``i`` is 0 where the mask equals ``num_class[i]``
        and 1 everywhere else. Both batched (N, H, W, C) and single
        (H, W, C) masks are accepted.

    Binary mode (``flag_multi_class`` falsy):
        Only when ``img`` contains values above 1, ``img`` and ``mask`` are
        divided by 255 and the mask is thresholded at 0.5 into {0, 1}.
        Otherwise both inputs are returned unchanged.

    Returns:
        Tuple ``(img, mask)``.
    """
    if flag_multi_class:
        img = img / 255
        # Keep only the first channel of the mask (grayscale class values).
        mask = mask[:, :, :, 0] if len(mask.shape) == 4 else mask[:, :, 0]
        # Start from all ones and clear the channel of the matching class.
        new_mask = np.ones(mask.shape + (len(num_class),))
        for i, class_value in enumerate(num_class):
            new_mask[mask == class_value, i] = 0
        # The former np.reshape to new_mask's own shape was a no-op for
        # batched input and raised IndexError for single (H, W, C) masks,
        # so it has been removed.
        mask = new_mask
    elif np.max(img) > 1:
        img = img / 255
        mask = mask / 255
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0
    return (img, mask)

获取数据完成之后要定义模型

import numpy as np
import os
import skimage.io as io
import skimage.transform as trans
import numpy as np
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as keras


def unet(pretrained_weights = None,input_size = (512, 512, 1),num_outputs = 4):
    """Build and compile a U-Net style encoder/decoder segmentation model.

    The encoder applies four (conv, conv, 2x2 max-pool) stages with 64, 128,
    256 and 512 filters, followed by a 1024-filter bottleneck. Dropout (0.5)
    is applied after the fourth stage and after the bottleneck. The decoder
    mirrors this with four (upsample, 2x2 conv, concatenate with the matching
    encoder output, conv, conv) stages, then an 8-filter conv and a 1x1
    sigmoid output layer.

    Args:
        pretrained_weights: optional path passed to ``model.load_weights``.
        input_size: shape of one input sample, ``(height, width, channels)``.
        num_outputs: number of channels of the sigmoid output layer
            (defaults to 4, the value previously hard-coded here).

    Returns:
        The model, compiled with Adam (lr=1e-4), binary cross-entropy loss
        and an accuracy metric.
    """
    def conv(filters, kernel_size, tensor):
        # Every hidden convolution shares the same activation/padding/init.
        return Conv2D(filters, kernel_size, activation = 'relu', padding = 'same',
                      kernel_initializer = 'he_normal')(tensor)

    def double_conv(filters, tensor):
        # Two stacked 3x3 convolutions with the same filter count.
        return conv(filters, 3, conv(filters, 3, tensor))

    def up_merge(filters, skip, tensor):
        # Upsample by 2, apply a 2x2 conv, then concatenate with the encoder
        # feature map along the channel axis (skip connection).
        upsampled = conv(filters, 2, UpSampling2D(size = (2, 2))(tensor))
        return concatenate([skip, upsampled], axis = 3)

    inputs = Input(input_size)

    # Encoder
    conv1 = double_conv(64, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = double_conv(128, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = double_conv(256, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    drop4 = Dropout(0.5)(double_conv(512, pool3))
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck
    drop5 = Dropout(0.5)(double_conv(1024, pool4))

    # Decoder
    conv6 = double_conv(512, up_merge(512, drop4, drop5))
    conv7 = double_conv(256, up_merge(256, conv3, conv6))
    conv8 = double_conv(128, up_merge(128, conv2, conv7))
    conv9 = double_conv(64, up_merge(64, conv1, conv8))
    conv9 = conv(8, 3, conv9)

    # One sigmoid channel per class; the author trained a 4-class model,
    # hence the default of 4 output channels.
    conv10 = Conv2D(num_outputs, 1, activation = 'sigmoid')(conv9)

    # `inputs`/`outputs` are the Keras 2 keyword names; the Keras 1 spellings
    # `input`/`output` are rejected by current Keras releases.
    model = Model(inputs = inputs, outputs = conv10)

    model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])

    #model.summary()

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

接着就开启训练吧

# Augmentation settings forwarded to ImageDataGenerator for both the images
# and the masks.
data_gen_args = dict(rotation_range=0.2,
                     width_shift_range=0.05,
                     height_shift_range=0.05,
                     shear_range=0.05,
                     zoom_range=0.05,
                     horizontal_flip=True,
                     fill_mode='nearest')

# NOTE: img_path, mask_path, img_files, mask_files and num_class are not
# defined in this snippet; assign them for your own dataset before running.
# num_class is the sequence of mask gray values, one entry per class.
myGene = trainGenerator(1,img_path,          # directory that contains the image folder
                        mask_path,           # directory that contains the mask folder
                        img_files,           # name of the image folder
                        mask_files,          # name of the mask folder (see the function definition and the Keras docs)
                        data_gen_args,True,  # True enables the multi-class mask encoding
                        num_class,target_size=(256,256))

# The network input size must match the generator's target_size; unet()
# defaults to 512x512, so the 256x256 grayscale shape is passed explicitly.
model = unet(input_size=(256, 256, 1))
# os.path.join avoids the invalid "\m" escape of the former 'saver\module.hdf5'
# literal and works on every platform.
model_checkpoint = ModelCheckpoint(os.path.join('saver', 'module.hdf5'), monitor='loss',verbose=1, save_best_only=True)
model.fit_generator(myGene,steps_per_epoch=300,epochs=1,callbacks=[model_checkpoint])

训练好之后要对test进行预测,下面用到的testGenerator等函数需要自己定义

# Run the trained model over the test set. testGenerator is not defined in
# this file and has to be supplied by the reader.
test_dir = "data/membrane/test"
num_test_batches = 30  # number of batches drawn from the generator
testGene = testGenerator(test_dir)
results = model.predict_generator(testGene, num_test_batches, verbose=1)

然后把result写入图片中

def save(save_path, result,num_class):
    """Write each prediction in ``result`` to ``<save_path>/<index>_predict.bmp``.

    Args:
        save_path: directory the BMP files are written into.
        result: iterable of per-image predictions indexed as
            ``item[row, col, channel]``, one channel per entry of
            ``num_class``.
        num_class: sequence of gray values, one per class. A pixel is
            painted with ``num_class[j]`` where channel ``j`` is below 0.5;
            when several channels qualify, the highest ``j`` wins. Pixels
            where no channel is below 0.5 keep the default gray value 128.
    """
    for i, item in enumerate(result):
        # Size the canvas from the prediction itself instead of assuming
        # 256x256, so models with other input sizes work too.
        pred = np.ones(item.shape[:2] + (1,)) * 128
        for j, class_value in enumerate(num_class):
            pred[item[:, :, j] < 0.5, :] = class_value
        cv2.imwrite(os.path.join(save_path, "%d_predict.bmp" % i), pred)

你可能感兴趣的:(Keras)