U-Net is a classic network design that is widely used in image segmentation tasks. Many newer methods build on it and fold in more recent design ideas, and they often achieve good results even on small datasets.
For an introduction to the U-Net family of models, see: https://zhuanlan.zhihu.com/p/57530767
That article covers U-Net, 3D U-Net, TernausNet, Res-UNet, Dense U-Net, MultiResUNet, R2U-Net, Attention U-Net, and other variants; how effective these methods really are still needs to be verified in later experiments.
#coding=utf-8
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import argparse
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D,MaxPooling2D,UpSampling2D,BatchNormalization,Reshape,Permute,Activation,Input
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from keras.models import Model
from keras.layers.merge import concatenate
from PIL import Image
import cv2
import random
import os
from tqdm import tqdm
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
seed = 7
np.random.seed(seed)
def load_img(path, grayscale=False):
    if grayscale:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(path)
    img = np.array(img, dtype="float") / 255.0
    return img
filepath ='./unet_train/building/'
def get_train_val(val_rate=0.25):
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set
train_set,val_set = get_train_val()
# Check the sizes of the train/validation split
len(train_set), len(val_set)
# data for training
def generateData(batch_size, data=[]):
    # print('generateData...')
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            train_label.append(label)
            if batch % batch_size == 0:
                # print('got enough for one batch!')
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0
# data for validation
def generateValidData(batch_size, data=[]):
    # print('generateValidData...')
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label)
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0
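Since load_img divides every image by 255, label masks stored as 0/255 grayscale PNGs arrive in the generator as 0/1 floats, which is exactly what binary_crossentropy expects. A quick sanity check before training can confirm this (a minimal sketch; it just inspects whichever file happens to come first in train_set and assumes the labels really are stored as 0/255 images):
# Optional sanity check: the label values seen by the generator should be only 0 and 1
sample_label = load_img(filepath + 'label/' + train_set[0], grayscale=True)
print(np.unique(sample_label))  # expect something like [0. 1.]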
[Figure: https://i.loli.net/2021/03/18/8koDuElSyvxc7Hd.png]
# Define the model with the Keras functional (Model) API
img_w = 256
img_h = 256
def unet():
    inputs = Input((img_w, img_h, 3))
    # Contracting path: convolution + pooling
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    # Bottleneck: convolution only, no pooling
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)
    # Expanding path: upsampling with skip connections
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=-1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)
    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=-1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)
    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=-1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)
    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=-1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)
    # Binary segmentation, so the output layer uses a sigmoid; a multi-class model would use softmax instead
    conv10 = Conv2D(1, (1, 1), activation="sigmoid")(conv9)
    # conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)
    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
unet_model = unet()
unet_model.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 256, 256, 32) 896 input_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 256, 256, 32) 9248 conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 128, 128, 32) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 128, 128, 64) 18496 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 128, 128, 64) 36928 conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 64, 64, 64) 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 64, 64, 128) 73856 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 64, 64, 128) 147584 conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 32, 32, 128) 0 conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 32, 32, 256) 295168 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 32, 32, 256) 590080 conv2d_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 16, 16, 256) 0 conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 16, 16, 512) 1180160 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 16, 16, 512) 2359808 conv2d_9[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 32, 32, 512) 0 conv2d_10[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 32, 32, 768) 0 up_sampling2d_1[0][0]
conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 32, 32, 256) 1769728 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 32, 32, 256) 590080 conv2d_11[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 64, 64, 256) 0 conv2d_12[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 64, 64, 384) 0 up_sampling2d_2[0][0]
conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 64, 64, 128) 442496 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 64, 64, 128) 147584 conv2d_13[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D) (None, 128, 128, 128 0 conv2d_14[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 128, 128, 192 0 up_sampling2d_3[0][0]
conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 128, 128, 64) 110656 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 128, 128, 64) 36928 conv2d_15[0][0]
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D) (None, 256, 256, 64) 0 conv2d_16[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 256, 256, 96) 0 up_sampling2d_4[0][0]
conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 256, 256, 32) 27680 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 256, 256, 32) 9248 conv2d_17[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 256, 256, 1) 33 conv2d_18[0][0]
==================================================================================================
Total params: 7,846,657
Trainable params: 7,846,657
Non-trainable params: 0
__________________________________________________________________________________________________
EPOCHS = 10
BS = 16
#data_shape = 360*480
img_w = 256
img_h = 256
# one of the classes is the background
# n_label = 4 + 1
n_label = 1
# The classes list and LabelEncoder below are leftovers from the multi-class version of this
# script and are not actually used in the binary (building vs. background) training here
classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)
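If you did want the multi-class setup that these leftovers (and the commented-out softmax line in unet()) hint at, the output head, loss, and label encoding would change roughly as follows. This is only a sketch under the assumption that there are n_label classes including background and that the label masks store integer class ids; it is not part of the training run described here.
# Sketch of the multi-class changes (not used in this binary experiment):
#   inside unet():  conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)
#   and compile with loss='categorical_crossentropy' instead of 'binary_crossentropy'.
# The generators would then one-hot encode each integer label mask, for example:
dummy_mask = np.random.randint(0, 5, size=(img_w, img_h))  # fake mask with 5 classes (4 + background)
onehot = to_categorical(dummy_mask.ravel(), num_classes=5).reshape((img_w, img_h, 5))
print(onehot.shape)  # (256, 256, 5)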
# Save the best model seen so far (note: newer tf.keras versions name this metric 'val_accuracy' instead of 'val_acc')
modelcheck = ModelCheckpoint("unet_buildings.h5", monitor='val_acc', save_best_only=True, mode='max')
callable = [modelcheck]
# Split the dataset
train_set,val_set = get_train_val()
train_numb = len(train_set)
valid_numb = len(val_set)
print ("the number of train data is",train_numb)
print ("the number of val data is",valid_numb)
the number of train data is 7500
the number of val data is 2500
# Train the model
H = unet_model.fit_generator(generator=generateData(BS, train_set), steps_per_epoch=train_numb // BS, epochs=EPOCHS, verbose=1,
                             validation_data=generateValidData(BS, val_set), validation_steps=valid_numb // BS,
                             callbacks=callable, max_queue_size=1)
# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
N = EPOCHS
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on U-Net Satellite Seg")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
# Save the figure (call savefig before show, otherwise the saved file may be empty)
plt.savefig("unet_building.png")
plt.show()
from keras.models import load_model
# Test image file names and directory
TEST_SET = ['1.png','2.png', '3.png']
image_dir = "./test_image/"
# Load the trained model
print("[INFO] loading network...")
model = load_model("unet_buildings20.h5")
# Tile size and sliding-window stride
image_size = 256
stride = 256
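The loop below pads each test image up to the next multiple of 256 and then slides a 256×256 window over it, predicting one tile at a time. Here is a quick worked example of the padding arithmetic (the 1000×1500 size is purely illustrative, not one of the actual test images); note that a side that is already an exact multiple of 256 still gets one extra tile of padding with this formula:
# Worked example of the padding formula used below (hypothetical 1000 x 1500 image)
demo_h, demo_w = 1000, 1500
print((demo_h // stride + 1) * stride)  # (3 + 1) * 256 = 1024
print((demo_w // stride + 1) * stride)  # (5 + 1) * 256 = 1536 -> a 4 x 6 grid of 256 x 256 tiles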
for n in range(len(TEST_SET)):
    path = TEST_SET[n]
    # load the image
    image = cv2.imread(image_dir + path)
    h, w, _ = image.shape
    # Pad the image so that its height and width become multiples of 256
    padding_h = (h // stride + 1) * stride
    padding_w = (w // stride + 1) * stride
    padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
    padding_img[0:h, 0:w, :] = image[:, :, :]
    # Scale to [0, 1] so that prediction inputs match the training preprocessing in load_img
    padding_img = padding_img.astype("float") / 255.0
    padding_img = img_to_array(padding_img)
    print('src:', padding_img.shape)
    mask_whole = np.zeros((padding_h, padding_w), dtype=np.uint8)
    print(padding_h, padding_w)
    for i in range(padding_h // stride):
        for j in range(padding_w // stride):
            crop = padding_img[i * stride:i * stride + image_size, j * stride:j * stride + image_size, :3]
            ch, cw, _ = crop.shape
            if ch != 256 or cw != 256:
                print('invalid size!')
                continue
            # Predict on each 256 x 256 tile
            crop = np.expand_dims(crop, axis=0)
            pred = model.predict(crop, verbose=2)
            # print(np.unique(pred))
            # Threshold the sigmoid output at 0.5 to obtain a binary mask
            # (a bare astype(np.uint8) would floor almost every value to 0)
            pred = (pred.reshape((256, 256)) > 0.5).astype(np.uint8)
            # print('pred:', pred.shape)
            mask_whole[i * stride:i * stride + image_size, j * stride:j * stride + image_size] = pred[:, :]
    cv2.imwrite('./predict/20pre' + str(n + 1) + '.png', mask_whole[0:h, 0:w])
    print("Finished predicting image %d" % (n + 1))
    # Visualize the mask: scale 0/1 up to 0/255 and colour the predicted building pixels yellow
    visualize = mask_whole * 255
    # (255, 255, 0) is yellow
    result = np.array([visualize, visualize, np.zeros((padding_h, padding_w), dtype=np.uint8)])
    result = result.transpose(1, 2, 0)
    print(result.shape)
    img = Image.fromarray(result).convert('RGB')  # convert the array back to an image
    img.save('./predict/20pre' + str(n + 1) + '.tif')
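A quick way to eyeball the result, not part of the original script, is to overlay the binary mask on the source image. A small sketch using the image and mask_whole left over from the last loop iteration (all variable names come from the loop above; the output filename and blending weights are arbitrary):
# Optional: blend the predicted mask onto the source image for visual inspection
overlay = image.copy()
overlay[mask_whole[0:h, 0:w] == 1] = (0, 255, 255)      # paint predicted building pixels yellow (BGR)
blended = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)   # 60% original, 40% overlay
cv2.imwrite('./predict/overlay_last.png', blended)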
The U-Net model makes full use of image features at different levels, which gives it good learning and representational capacity, but it relies on multi-stage cascaded convolutional networks. These cascaded frameworks extract regions of interest and make dense predictions, and the repeated extraction of low-level features leads to excessive, redundant use of computational resources.
U-Net-style networks have consistently performed well on small-object segmentation tasks. My guess is that this is precisely because of that repeated computation of low-level features (colour, texture, and so on): much like an attention mechanism, the extra emphasis on low-level features makes it easier to extract information about small objects.
In remote sensing imagery, most target land-cover classes are small objects, which may be why U-Net is so popular in Kaggle remote sensing segmentation competitions.
Due to disk and memory limits, I only trained on 10,000 images, so the accuracy is not very high; locally the results look passable. All I can say is: if you know, you know!
In addition, because the U-Net model loses some precision during pooling, the building boundaries here come out rather blurry, or rather smoothed. A senior labmate suggested that an FPN or other regularization methods could help with this; maybe so.