Unet是受到FCN启发针对医学图像做语义分割, 且可以利用少量的数据学习到一个对边缘提取十分鲁棒的模型, 在生物医学图像分割领域有很大作用。
网络结构如下图:
如上图其结构如英文字母u,所以被命名为unet。其建立在FCN的架构上, 首先是从左侧输入开始的一系列卷积层,这里主要有5层,目的是用来提取图片的特征, 这里可以使用vgg或者resnet等经典的特征提取网络。然后是右侧的结构, 首先从最下层开始将提取出的特征进行上采样,上采样后的特征与其上一层的特征的形状相同, 然后将两个特征聚合在一起,并且添加卷积层进行通道数的缩减,然后对缩减后的特征进行上采样, 并重复之前的操作。最后上采样到与原图的形状相同的时候,再添加一个对每个像素点进行分类的卷积层。
keras实现如下:
def Unet(input_shape=(256,256,3),num_class=1):
input= Input(input_shape)
feat1,feat2,feat3,feat4,feat5 = VGG16(input)
channels = [64, 128, 256, 512]
P5_up = UpSampling2D(size=(2,2))(feat5)
P4 = Concatenate(axis=3)([feat4,P5_up])
P4 = Conv2D(channels[3],3,activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02))(P4)
P4 = Conv2D(channels[3],3,activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02))(P4)
P4_up = UpSampling2D(size=(2,2))(P4)
P3 = Concatenate(axis=3)([P4_up,feat3])
P3 = Conv2D(channels[2], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P3)
P3 = Conv2D(channels[2], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P3)
P3_up = UpSampling2D(size=(2, 2))(P3)
P2 = Concatenate(axis=3)([P3_up, feat2])
P2 = Conv2D(channels[1], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P2)
P2 = Conv2D(channels[1], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P2)
P2_up = UpSampling2D(size=(2, 2))(P2)
P1 = Concatenate(axis=3)([P2_up, feat1])
P1 = Conv2D(channels[0], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P1)
P1 = Conv2D(channels[0], 3, activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02))(P1)
P1 = Conv2D(num_class,1,activation="sigmoid")(P1)
model = Model(inputs=input,outputs=P1)
return model
首先是第5行调用的VGG16方法,这里是对应网络结构左侧的内容。这里使用VGG网络来做特征提取。 其返回了5个特征层的特征。
然后是第7行,这里定义了上采样后每层的特征数。然后是第9行,先对最底层的特征(feat5) 进行上采样;然后是第11行使用Concatenate网络将fea5上采样的结果(P5_up)与feat4进行连接; 然后是第13行和第14行使用两个3*3的卷积来进行通道数的缩减。
然后是一步步上采样和卷积直到第37行,用一个1*1的卷积来做像素点的分类。 最后是第39行根据上述的结构来创建模型。
在上文提到了特征提取是使用的是一个名为VGG16的方法,该方法内容如下:
def VGG16(input):
x = Conv2D(64,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
name="block1_conv1")(input)
x = Conv2D(64,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
name="block1_conv2")(x)
feat1 = x
x = MaxPooling2D((2,2),strides=(2,2),name="block1_pool")(x)
x = Conv2D(128,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
name="block2_conv1")(x)
x = Conv2D(128,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
name="block2_conv2")(x)
feat2 = x
x = MaxPooling2D((2,2),strides=(2,2),name="block2_pool")(x)
x = Conv2D(256,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
name="block3_conv1")(x)
x = Conv2D(256,(3,3),activation="relu",padding="same",kernel_initializer=RandomNormal(stddev=0.02),
name="block3_conv2")(x)
x = Conv2D(256, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
name="block3_conv3")(x)
feat3 = x
x = MaxPooling2D((2, 2), strides=(2, 2), name="block3_pool")(x)
x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
name="block4_conv1")(x)
x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
name="block4_conv2")(x)
x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
name="block4_conv3")(x)
feat4 = x
x = MaxPooling2D((2, 2), strides=(2, 2), name="block4_pool")(x)
x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
name="block5_conv1")(x)
x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
name="block5_conv2")(x)
x = Conv2D(512, (3, 3), activation="relu", padding="same", kernel_initializer=RandomNormal(stddev=0.02),
name="block5_conv3")(x)
feat5 = x
return feat1,feat2,feat3,feat4,feat5
VGG的实现实际很简单,首先是两个3*3的卷积,这两个卷积的结果作为feat1; 然后是一个最大池化层,池化后再接两个卷积层,卷积的结果作为feat2,然后一直重复这个操作, 直到feat5。最后再将feat1到feat都返回。
这里的模型是使用keras实现的,keras模型的训练很简单,只需要调用fit方法 或fit_generator方法便可。
在训练前只需要处理好模型的输入与输出便可。语义分割需要的如下图所示:
上图是一个语义分割的数据集,左边的图片是数据集的输入,右边的是输出。 这是一个二分类的语义分割任务,右边白色的标签为1,黑色的标签为0。
主要训练代码如下:
def train():
#图片与标签路径
train_csv = "XXX"
train_image = "XXX"
#模型保存路径
checkpoint_path = "../savemodel/unet_tld2.{epoch:02d}-{val_loss:.2f}.h5"
#除了输入与输出
trains, vals, test = utils.get_HSAgenerator_all(train_csv, train_image,train_batch_size=8)
cb = [
ModelCheckpoint(checkpoint_path, verbose=0),
TensorBoard(log_dir="../logs/unet_tld2")
]
model = Unet(input_shape=(256, 256, 3))
lr = 1e-4
a = Adam(learning_rate=lr)
model.compile(loss=utils.lossdif2, optimizer=a, metrics="acc")
model.fit_generator(trains, steps_per_epoch=2625, epochs=50, validation_data=vals,
validation_steps=375, callbacks=cb)
训练的代码主要如上所示,这里主要是使用的fit_generator方法来训练。 他要求传入的输入是一个generator类,这个类可以使用yield关键字来快速实现。
然后是callbacks,这里定义了两个类:ModelCheckpoint和TensorBoard。 ModelCheckpoint的主要作用是在模型训练完成指定的epoch后保存模型,默认是每个批次都保存模型。 TensorBoard主要的作用是存储TensorBoard需要的数据。
最后是模型编译需要的一些参数。损失函数一般可以交叉熵函数,optimizer可以使用adam或 sgd(随机梯度下降)等,metrics这里使用的acc(准确率)。
这里的模型是使用keras实现的,keras模型的预测也很简单。主要代码如下:
def predict_hsa():
name = "XXX"
model_path = "XXX"
r = []
m = Unet(input_shape=(256, 256, 3))
m.load_weights(model_path)
threshold = 0.5
image = cv2.imread(name)
image = cv2.resize(image,(256,256))
p = m.predict(np.array([image]))
mask = p[0]
mask = np.where(mask >= threshold, 1, 0).astype(np.uint8)
mask = cv2.resize(mask, (512, 512))
utils.show(image,mask)
首先是第8行和第9行,这里先定义了deeplab模型,然后加载训练好的模型。然后是第12行到 第14行,这里主要是读取图片并resize成模型需求的形状。然后是第16行调用predict方法进行预测,最 后是第17行到第20行,这里在处理模型预测结果并显示。
模型的预测效果主要如下: