UNet网络实现及解析

简介

 Unet是受到FCN启发针对医学图像做语义分割, 且可以利用少量的数据学习到一个对边缘提取十分鲁棒的模型, 在生物医学图像分割领域有很大作用。

网络结构

  网络结构如下图:

UNet网络实现及解析_第1张图片

  如上图其结构如英文字母u,所以被命名为unet。其建立在FCN的架构上, 首先是从左侧输入开始的一系列卷积层,这里主要有5层,目的是用来提取图片的特征, 这里可以使用vgg或者resnet等经典的特征提取网络。然后是右侧的结构, 首先从最下层开始将提取出的特征进行上采样,上采样后的特征与其上一层的特征的形状相同, 然后将两个特征聚合在一起,并且添加卷积层进行通道数的缩减,然后对缩减后的特征进行上采样, 并重复之前的操作。最后上采样到与原图的形状相同的时候,再添加一个对每个像素点进行分类的卷积层。

keras实现

整体实现

 keras实现如下:

def Unet(input_shape=(256,256,3),num_class=1):

    input= Input(input_shape)

    feat1,feat2,feat3,feat4,feat5 = VGG16(input)

    channels = [64, 128, 256, 512]

    P5_up = UpSampling2D(size=(2,2))(feat5)

    P4 = Concatenate(axis=3)([feat4,P5_up])

    P4 = Conv2D(channels[3],3,activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02))(P4)
    P4 = Conv2D(channels[3],3,activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02))(P4)

    P4_up = UpSampling2D(size=(2,2))(P4)

    P3 = Concatenate(axis=3)([P4_up,feat3])

    P3 = Conv2D(channels[2], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P3)
    P3 = Conv2D(channels[2], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P3)

    P3_up = UpSampling2D(size=(2, 2))(P3)

    P2 = Concatenate(axis=3)([P3_up, feat2])

    P2 = Conv2D(channels[1], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P2)
    P2 = Conv2D(channels[1], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P2)

    P2_up = UpSampling2D(size=(2, 2))(P2)

    P1 = Concatenate(axis=3)([P2_up, feat1])

    P1 = Conv2D(channels[0], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P1)
    P1 = Conv2D(channels[0], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P1)

    P1 = Conv2D(num_class,1,activation="sigmoid")(P1)

    model = Model(inputs=input,outputs=P1)
    return model

  首先是第5行调用的VGG16方法,这里是对应网络结构左侧的内容。这里使用VGG网络来做特征提取。 其返回了5个特征层的特征。

  然后是第7行,这里定义了上采样后每层的特征数。然后是第9行,先对最底层的特征(feat5) 进行上采样;然后是第11行使用Concatenate网络将fea5上采样的结果(P5_up)与feat4进行连接; 然后是第13行和第14行使用两个3*3的卷积来进行通道数的缩减。

  然后是一步步上采样和卷积直到第37行,用一个1*1的卷积来做像素点的分类。 最后是第39行根据上述的结构来创建模型。

vgg特征提取

  在上文提到了特征提取是使用的是一个名为VGG16的方法,该方法内容如下:

def VGG16(input):
    x = Conv2D(64,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
               name="block1_conv1")(input)

    x = Conv2D(64,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
               name="block1_conv2")(x)

    feat1 = x

    x = MaxPooling2D((2,2),strides=(2,2),name="block1_pool")(x)

    x = Conv2D(128,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
               name="block2_conv1")(x)

    x = Conv2D(128,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
               name="block2_conv2")(x)

    feat2 = x

    x = MaxPooling2D((2,2),strides=(2,2),name="block2_pool")(x)

    x = Conv2D(256,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
               name="block3_conv1")(x)

    x = Conv2D(256,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
               name="block3_conv2")(x)

    x = Conv2D(256, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
               name="block3_conv3")(x)

    feat3 = x

    x = MaxPooling2D((2, 2), strides=(2, 2), name="block3_pool")(x)

    x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
               name="block4_conv1")(x)

    x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
               name="block4_conv2")(x)

    x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
               name="block4_conv3")(x)

    feat4 = x

    x = MaxPooling2D((2, 2), strides=(2, 2), name="block4_pool")(x)

    x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
               name="block5_conv1")(x)

    x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
               name="block5_conv2")(x)

    x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
               name="block5_conv3")(x)

    feat5 = x
    return feat1,feat2,feat3,feat4,feat5

  VGG的实现实际很简单,首先是两个3*3的卷积,这两个卷积的结果作为feat1; 然后是一个最大池化层,池化后再接两个卷积层,卷积的结果作为feat2,然后一直重复这个操作, 直到feat5。最后再将feat1到feat都返回。

模型训练与预测

模型训练

  这里的模型是使用keras实现的,keras模型的训练很简单,只需要调用fit方法 或fit_generator方法便可。

  在训练前只需要处理好模型的输入与输出便可。语义分割需要的如下图所示:

UNet网络实现及解析_第2张图片

  上图是一个语义分割的数据集,左边的图片是数据集的输入,右边的是输出。 这是一个二分类的语义分割任务,右边白色的标签为1,黑色的标签为0。

  主要训练代码如下:

def train():
    #图片与标签路径
    train_csv = "XXX"
    train_image = "XXX"
    
    #模型保存路径
    checkpoint_path = "../savemodel/unet_tld2.{epoch:02d}-{val_loss:.2f}.h5"
    #除了输入与输出
    trains, vals, test = utils.get_HSAgenerator_all(train_csv, train_image,train_batch_size=8)

    cb = [
        ModelCheckpoint(checkpoint_path, verbose=0),
        TensorBoard(log_dir="../logs/unet_tld2")
    ]


    model = Unet(input_shape=(256, 256, 3))
    lr = 1e-4
    a = Adam(learning_rate=lr)
    model.compile(loss=utils.lossdif2, optimizer=a, metrics="acc")

    model.fit_generator(trains, steps_per_epoch=2625, epochs=50, validation_data=vals,
                        validation_steps=375, callbacks=cb)

  训练的代码主要如上所示,这里主要是使用的fit_generator方法来训练。 他要求传入的输入是一个generator类,这个类可以使用yield关键字来快速实现。

  然后是callbacks,这里定义了两个类:ModelCheckpoint和TensorBoard。 ModelCheckpoint的主要作用是在模型训练完成指定的epoch后保存模型,默认是每个批次都保存模型。 TensorBoard主要的作用是存储TensorBoard需要的数据。

  最后是模型编译需要的一些参数。损失函数一般可以交叉熵函数,optimizer可以使用adam或 sgd(随机梯度下降)等,metrics这里使用的acc(准确率)。

模型预测

  这里的模型是使用keras实现的,keras模型的预测也很简单。主要代码如下:

def predict_hsa():
   
    name = "XXX"
    model_path = "XXX"

    r = []

    m = Unet(input_shape=(256, 256, 3))
    m.load_weights(model_path)

 
    threshold = 0.5
    image = cv2.imread(name)
    image = cv2.resize(image,(256,256))

    p = m.predict(np.array([image]))
    mask = p[0]
    mask = np.where(mask >= threshold, 1, 0).astype(np.uint8)
    mask = cv2.resize(mask, (512, 512))
    utils.show(image,mask)

  首先是第8行和第9行,这里先定义了deeplab模型,然后加载训练好的模型。然后是第12行到 第14行,这里主要是读取图片并resize成模型需求的形状。然后是第16行调用predict方法进行预测,最 后是第17行到第20行,这里在处理模型预测结果并显示。

  模型的预测效果主要如下:

UNet网络实现及解析_第3张图片

你可能感兴趣的:(深度学习,神经网络,keras)