基于卷积神经网络特征图的二值图像分割

       目标检测是当前大火的一个研究方向,FasterRCNN、Yolov3等一系列结构也都在多目标检测的各种应用场景或者竞赛中取得了很不错的成绩。但是想象一下,假设我们需要通过图像检测某个产品上是否存在缺陷,或者通过卫星图判断某片海域是否有某公司的船只,再或者需要研发一套无人驾驶中基于图像的避障设备。这些问题的共同特点是,我们只需要检测出某种特定目标在图片中的位置,并不需要在同一幅图中识别出多个目标。这种时候,FasterRCNN或者Yolov3等算法当然完全能够胜任,但是多少有些杀鸡用牛刀的感觉,因为考虑到这些网络需要相对较多的计算资源。当我们仅仅需要检测某一类特定目标的话,我们更希望网络能够专注于学习到那一个特定目标的特征。15年所提出的U-net网络正是通过多个多通道特征图最大化的利用输入图片的特征,以实现目标的二值图像分割,并在kaggle上的各类图像分割相关赛事中被广泛使用。U-net论文:https://arxiv.org/abs/1505.04597,我写的相关博客:https://blog.csdn.net/shi2xian2wei2/article/details/84345025.。

       这里仅建立一个比较简单的网络模型,来对基于卷积神经网络特征图的二值图像分割方法进行说明。网络基于keras建立,结构如下:

def Conv2d_BN(x, nb_filter, kernel_size, strides=(1,1), padding='same'):
    x = Conv2D(nb_filter, kernel_size, strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3)(x)
    x = LeakyReLU(alpha=0.1)(x)
    return x

def Conv2dT_BN(x, filters, kernel_size, strides=(2,2), padding='same'):
    x = Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3)(x)
    x = LeakyReLU(alpha=0.1)(x)
    return x

inpt = Input(shape=(input_size_1, input_size_2, 3))
x = Conv2d_BN(inpt, 4, (3, 3))
x = MaxPooling2D(pool_size=(3,3),strides=(2,2),padding='same')(x)
x = Conv2d_BN(x, 8, (3, 3))
x = MaxPooling2D(pool_size=(3,3),strides=(2,2),padding='same')(x)
x = Conv2d_BN(x, 16, (3, 3))
x = AveragePooling2D(pool_size=(3,3),strides=(2,2),padding='same')(x)
x = Conv2d_BN(x, 32, (3, 3))
x = AveragePooling2D(pool_size=(3,3),strides=(2,2),padding='same')(x)
x = Conv2d_BN(x, 64, (3, 3))
x = Dropout(0.5)(x)
x = Conv2d_BN(x, 64, (1, 1))
x = Dropout(0.5)(x)
x = Conv2dT_BN(x, 32, (3, 3))
x = Conv2dT_BN(x, 16, (3, 3))
x = Conv2dT_BN(x, 8, (3, 3))
x = Conv2dT_BN(x, 4, (3, 3))

x = Conv2DTranspose(filters=3,kernel_size=(3,3),strides=(1,1),padding='same',activation='sigmoid')(x)

model = Model(inpt, x)
model.summary()

       网络输入的图片大小为256*256*3。这样搭起来的网络,只有50000+参数,如果是实际应用的话,再优化一下放到移动设备里边实时性应该还是没问题的。

       由于网络输出的二值图像分割结果尺寸应该和原始图片保持一致,因此在网络使用了池化层对图片进行压缩之后,需要进行上采样来对图片的尺寸进行还原。一般而言,神经网络中常用的上采样操作是up pooling或者转置卷积,插值的话比较少见。个人觉得转置卷积的效果会优于up pooling。网络使用最大值池化层来突出原始图像的边缘特征,同时均值池化层用来保留图像中的位置特征,Dropout层加入噪声防止过拟合。卷积层与反卷积层基本呈对称结构,来方便对训练集标签进行更为自然的学习。

       网络训练与测试所使用的数据集,是在网上找到一些无异物的铁路图像作为背景,同时基于VOC2012数据集中的图片以及其提供的SegmentationObject标签,将目标物体随机缩放、旋转后,与背景铁路图像随机组合生成伪造数据。数据集和标签大概长下面这样:

基于卷积神经网络特征图的二值图像分割_第1张图片基于卷积神经网络特征图的二值图像分割_第2张图片

基于卷积神经网络特征图的二值图像分割_第3张图片基于卷积神经网络特征图的二值图像分割_第4张图片

基于卷积神经网络特征图的二值图像分割_第5张图片基于卷积神经网络特征图的二值图像分割_第6张图片

基于卷积神经网络特征图的二值图像分割_第7张图片基于卷积神经网络特征图的二值图像分割_第8张图片

       

       生成的训练集和测试集各包含1000张图片,训练集与测试集放置的目标物体不同,训练10个Epoch。训练之后的网络对测试集的分类效果如下:

                 原始图像                         真实标签                         检测标签

基于卷积神经网络特征图的二值图像分割_第9张图片

基于卷积神经网络特征图的二值图像分割_第10张图片

基于卷积神经网络特征图的二值图像分割_第11张图片

基于卷积神经网络特征图的二值图像分割_第12张图片

       可以看出,即便是在限制条件较多的情况下,网络也能够取得较好的检测效果。对于测试集最后一张图的小目标,网络检测结果稍差,一方面是因为训练集较小,训练不完全的缘故,但也有网络容量本身的问题。在需要进行精度更高的检测的情况下,可以适当将网络扩大或加深,简单的将网络各层中卷积单元数量同时增加相同倍数就能得到更加好的结果,但相应的计算速度会有所下降。

       当然,伪造数据集最大的问题在于背景多样性不足,可能在背景更加复杂的情况下,所需要的网络容量也相应会更大。

       网络使用的完整代码如下,数据读取写的比较啰嗦,反正就那么个意思:

import numpy as np
import random
import os

from keras.models import save_model, load_model, Model
from keras.layers import Input, Dropout, BatchNormalization, LeakyReLU
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv2DTranspose
import matplotlib.pyplot as plt
from skimage import io
from skimage.transform import resize

input_name = os.listdir('train_data3/JPEGImages')

n = len(input_name)
batch_size = 8
input_size_1 = 256
input_size_2 = 256

"""
Batch_data
"""
def batch_data(input_name, n, batch_size = 8, input_size_1 = 256, input_size_2 = 256):
    rand_num = random.randint(0, n-1)
    img1 = io.imread('train_data3/JPEGImages/'+input_name[rand_num]).astype("float")
    img2 = io.imread('train_data3/TargetImages/'+input_name[rand_num]).astype("float")
    img1 = resize(img1, [input_size_1, input_size_2, 3])
    img2 = resize(img2, [input_size_1, input_size_2, 3])
    img1 = np.reshape(img1, (1, input_size_1, input_size_2, 3))
    img2 = np.reshape(img2, (1, input_size_1, input_size_2, 3))
    img1 /= 255
    img2 /= 255
    batch_input = img1
    batch_output = img2
    for batch_iter in range(1, batch_size):
        rand_num = random.randint(0, n-1)
        img1 = io.imread('train_data3/JPEGImages/'+input_name[rand_num]).astype("float")
        img2 = io.imread('train_data3/TargetImages/'+input_name[rand_num]).astype("float")
        img1 = resize(img1, [input_size_1, input_size_2, 3])
        img2 = resize(img2, [input_size_1, input_size_2, 3])
        img1 = np.reshape(img1, (1, input_size_1, input_size_2, 3))
        img2 = np.reshape(img2, (1, input_size_1, input_size_2, 3))
        img1 /= 255
        img2 /= 255
        batch_input = np.concatenate((batch_input, img1), axis = 0)
        batch_output = np.concatenate((batch_output, img2), axis = 0)
    return batch_input, batch_output

def Conv2d_BN(x, nb_filter, kernel_size, strides=(1,1), padding='same'):
    x = Conv2D(nb_filter, kernel_size, strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3)(x)
    x = LeakyReLU(alpha=0.1)(x)
    return x

def Conv2dT_BN(x, filters, kernel_size, strides=(2,2), padding='same'):
    x = Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3)(x)
    x = LeakyReLU(alpha=0.1)(x)
    return x

inpt = Input(shape=(input_size_1, input_size_2, 3))
x = Conv2d_BN(inpt, 4, (3, 3))
x = MaxPooling2D(pool_size=(3,3),strides=(2,2),padding='same')(x)
x = Conv2d_BN(x, 8, (3, 3))
x = MaxPooling2D(pool_size=(3,3),strides=(2,2),padding='same')(x)
x = Conv2d_BN(x, 16, (3, 3))
x = AveragePooling2D(pool_size=(3,3),strides=(2,2),padding='same')(x)
x = Conv2d_BN(x, 32, (3, 3))
x = AveragePooling2D(pool_size=(3,3),strides=(2,2),padding='same')(x)
x = Conv2d_BN(x, 64, (3, 3))
x = Dropout(0.5)(x)
x = Conv2d_BN(x, 64, (1, 1))
x = Dropout(0.5)(x)
x = Conv2dT_BN(x, 32, (3, 3))
x = Conv2dT_BN(x, 16, (3, 3))
x = Conv2dT_BN(x, 8, (3, 3))
x = Conv2dT_BN(x, 4, (3, 3))

x = Conv2DTranspose(filters=3,kernel_size=(3,3),strides=(1,1),padding='same',activation='sigmoid')(x)

model = Model(inpt, x)
model.summary()

model.compile(loss='mean_squared_error', optimizer='Nadam', metrics=['accuracy'])

itr = 1000
S = []
for i in range(itr):
    print("iteration = ", i+1)
    if i < 500:
        bs = 4
    elif i < 2000:
        bs = 8
    elif i < 5000:
        bs = 16
    else:
        bs = 32
    train_X, train_Y = batch_data(input_name, n, batch_size = bs)
    model.fit(train_X, train_Y, epochs=1, verbose=0)

def batch_data_test(input_name, n, batch_size = 8, input_size_1 = 256, input_size_2 = 256):
    rand_num = random.randint(0, n-1)
    img1 = io.imread('test_data3/JPEGImages/'+input_name[rand_num]).astype("float")
    img2 = io.imread('test_data3/TargetImages/'+input_name[rand_num]).astype("float")
    img1 = resize(img1, [input_size_1, input_size_2, 3])
    img2 = resize(img2, [input_size_1, input_size_2, 3])
    img1 = np.reshape(img1, (1, input_size_1, input_size_2, 3))
    img2 = np.reshape(img2, (1, input_size_1, input_size_2, 3))
    img1 /= 255
    img2 /= 255
    batch_input = img1
    batch_output = img2
    for batch_iter in range(1, batch_size):
        rand_num = random.randint(0, n-1)
        img1 = io.imread('test_data3/JPEGImages/'+input_name[rand_num]).astype("float")
        img2 = io.imread('test_data3/TargetImages/'+input_name[rand_num]).astype("float")
        img1 = resize(img1, [input_size_1, input_size_2, 3])
        img2 = resize(img2, [input_size_1, input_size_2, 3])
        img1 = np.reshape(img1, (1, input_size_1, input_size_2, 3))
        img2 = np.reshape(img2, (1, input_size_1, input_size_2, 3))
        img1 /= 255
        img2 /= 255
        batch_input = np.concatenate((batch_input, img1), axis = 0)
        batch_output = np.concatenate((batch_output, img2), axis = 0)
    return batch_input, batch_output

test_name = os.listdir('test_data3/JPEGImages')
n_test = len(test_name)

test_X, test_Y = batch_data_test(test_name, n_test, batch_size = 1)
pred_Y = model.predict(test_X)
ii = 0
plt.figure()
plt.imshow(test_X[ii, :, :, :])
plt.axis('off')
plt.figure()
plt.imshow(test_Y[ii, :, :, :])
plt.axis('off')
plt.figure()
plt.imshow(pred_Y[ii, :, :, :])
plt.axis('off')

 

你可能感兴趣的:(神经网络,小石的机器学习专栏)