Unet是一种U型网络,分为左右两部分卷积,左边为下采样提取高维特征,右边为上采样并与左侧融合实现图像分割。这里使用TensorFlow实现Unet网络,实现对遥感影像的道路分割。
训练数据:
标签图像:
Unet实现:
import tensorflow as tf
import numpy as np
import cv2
import glob
import itertools
class UNet:
    """U-Net semantic segmentation network (valid-padding variant).

    The left (contracting) path extracts features with stacked 3x3 valid
    convolutions and 2x2 max-pooling; the right (expanding) path upsamples,
    concatenates center-cropped skip tensors, and ends in a per-pixel
    softmax over ``num_classes``.

    NOTE(review): the Cropping2D sizes in ``rightNetwork`` are hard-coded
    for a 572x572 input (output 388x388); other input sizes will not
    build — confirm before changing ``input_width``/``input_height``.
    """

    def __init__(
        self,
        input_width,
        input_height,
        num_classes,
        train_images,
        train_instances,
        val_images,
        val_instances,
        epochs,
        lr,
        lr_decay,
        batch_size,
        save_path
    ):
        # Spatial size of the network input; valid padding shrinks the
        # output (572x572 in -> 388x388 out).
        self.input_width = input_width
        self.input_height = input_height
        self.num_classes = num_classes
        # Directory prefixes: images are *.jpg, label masks are *.png.
        self.train_images = train_images
        self.train_instances = train_instances
        self.val_images = val_images
        self.val_instances = val_instances
        self.epochs = epochs
        self.lr = lr
        self.lr_decay = lr_decay
        self.batch_size = batch_size
        self.save_path = save_path

    def leftNetwork(self, inputs):
        """Contracting path.

        Two 3x3 valid convolutions per stage with 2x2 pooling between
        stages. Returns the five pre-pooling activations: four skip
        tensors plus the bottleneck.
        """
        stage_outputs = []
        x = inputs
        for filters in (64, 128, 256, 512, 1024):
            x = tf.keras.layers.Conv2D(filters, (3, 3), padding='valid', activation='relu')(x)
            x = tf.keras.layers.Conv2D(filters, (3, 3), padding='valid', activation='relu')(x)
            stage_outputs.append(x)
            if filters != 1024:  # no pooling after the bottleneck stage
                x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
        return stage_outputs

    def rightNetwork(self, inputs):
        """Expanding path.

        For each stage: upsample, concatenate a center-cropped skip
        tensor, then two 3x3 valid convolutions. A 1x1 convolution plus
        softmax yields per-pixel class probabilities.
        """
        c_1, c_2, c_3, c_4, x = inputs
        # Crop sizes compensate for the resolution lost to valid padding
        # on the contracting path (computed for a 572x572 input).
        for skip, crop, filters in ((c_4, 4, 512), (c_3, 16, 256),
                                    (c_2, 40, 128), (c_1, 88, 64)):
            x = tf.keras.layers.UpSampling2D((2, 2))(x)
            x = tf.keras.layers.concatenate(
                [tf.keras.layers.Cropping2D(crop)(skip), x], axis=3)
            x = tf.keras.layers.Conv2D(filters, (3, 3), padding='valid', activation='relu')(x)
            x = tf.keras.layers.Conv2D(filters, (3, 3), padding='valid', activation='relu')(x)
        x = tf.keras.layers.Conv2D(self.num_classes, (1, 1), padding='valid')(x)
        return tf.keras.layers.Activation('softmax')(x)

    def build_model(self):
        """Assemble the full U-Net as a Keras functional model."""
        inputs = tf.keras.Input(shape=[self.input_height, self.input_width, 3])
        outputs = self.rightNetwork(self.leftNetwork(inputs))
        return tf.keras.Model(inputs=inputs, outputs=outputs)

    def train(self):
        """Compile and train the network, then save it to ``save_path``."""
        import os  # local import: only needed for the resume check below
        G_train = self.dataGenerator(model='training')
        G_eval = self.dataGenerator(model='validation')
        # Resume from a previous checkpoint when one exists, otherwise
        # start from a freshly built network. (The original hard-coded
        # load_model('model.h5') and crashed on the very first run.)
        if os.path.exists('model.h5'):
            model = tf.keras.models.load_model('model.h5')
        else:
            model = self.build_model()
        model.compile(
            # BUG FIX: Adam's second positional argument is beta_1, not
            # the decay rate — pass both hyper-parameters by keyword.
            optimizer=tf.keras.optimizers.Adam(learning_rate=self.lr,
                                               decay=self.lr_decay),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        # fit_generator is deprecated; Model.fit accepts generators.
        model.fit(
            G_train,
            steps_per_epoch=5,
            validation_data=G_eval,
            validation_steps=5,
            epochs=self.epochs
        )
        model.save(self.save_path)

    def dataGenerator(self, model):
        """Yield ``(images, one-hot masks)`` batches forever.

        ``model`` selects the split: 'training' or 'validation'. (The
        parameter name is kept for backward compatibility; 'mode' would
        be clearer.) Any other value yields nothing, as before.
        """
        if model == 'training':
            yield from self._batches(self.train_images, self.train_instances)
        elif model == 'validation':
            yield from self._batches(self.val_images, self.val_instances)

    def _batches(self, image_dir, instance_dir):
        """Shared infinite batch generator for both splits.

        Images and masks are paired by sorted filename order; the pair
        list is cycled endlessly.
        """
        images = sorted(glob.glob(image_dir + '*.jpg'))
        instances = sorted(glob.glob(instance_dir + '*.png'))
        pairs = itertools.cycle(zip(images, instances))
        while True:
            batch_x = []
            batch_y = []
            for _ in range(self.batch_size):
                img_path, seg_path = next(pairs)
                # BGR image, resized to the network input, scaled to [0, 1].
                img = cv2.resize(cv2.imread(img_path, 1),
                                 (self.input_width, self.input_height)) / 255.0
                # NOTE(review): the mask is NOT resized — label files must
                # already match the network output size (388x388 for a
                # 572x572 input). Confirm against the dataset.
                seg = tf.keras.utils.to_categorical(cv2.imread(seg_path, 0),
                                                    self.num_classes)
                batch_x.append(img)
                batch_y.append(seg)
            yield np.array(batch_x), np.array(batch_y)
训练脚本:
# Build the trainer for binary road/background segmentation and run it.
# Guarded so importing this script does not kick off a training run.
if __name__ == '__main__':
    unet = UNet(
        input_width=572,     # U-Net valid-padding geometry: 572 in, 388 out
        input_height=572,
        num_classes=2,       # road vs. background
        train_images='./datasets/train/images/',
        train_instances='./datasets/train/instances/',
        val_images='./datasets/validation/images/',
        val_instances='./datasets/validation/instances/',
        epochs=100,
        lr=0.0001,
        lr_decay=0.00001,
        batch_size=100,
        save_path='model.h5'
    )
    unet.train()
这里仅分割道路和背景,属于二分类。TensorFlow/Keras 默认通道在后(channels-last),因此网络输出矩阵形状为388*388*2,进行100轮训练后保存模型进行推理验证。
推理脚本:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Load the trained U-Net and segment a single test image.
model = tf.keras.models.load_model('model.h5')
img_path = '17.jpg'
# Match the training preprocessing: resize to the network input size
# and scale pixel values to [0, 1].
img = cv2.resize(cv2.imread(img_path), (572, 572)) / 255.
img = np.expand_dims(img, 0)  # add the batch dimension
pred = model.predict(img)
# Per-pixel class index. BUG FIX: argmax returns int64; cast to uint8
# so cv2.imwrite stores a valid 8-bit grayscale image.
pred = np.argmax(pred[0], axis=-1).astype(np.uint8)
pred[pred == 1] = 255  # paint the road class white for display
cv2.imwrite('result.jpg', pred)
plt.imshow(pred)
plt.show()
测试图像:
推理结果:
将推理结果与原始图像叠加显示:
import cv2

# Overlay the predicted road outline on the original image.
img_path = '17.jpg'
result_path = 'result.jpg'
img = cv2.imread(img_path)
height, width = img.shape[:2]
result = cv2.imread(result_path)
# BUG FIX: cv2.resize takes dsize as (width, height), not
# (height, width), and the interpolation flag must be passed by
# keyword — the third positional argument is the optional dst buffer.
result = cv2.resize(result, (width, height), interpolation=cv2.INTER_LINEAR)
# Extract the road-region outline from the (near-)binary mask.
result = cv2.Canny(result, 0, 255)
# Paint the outline red (BGR) on the original image. Vectorized mask
# assignment replaces the original O(H*W) per-pixel Python loop.
img[result == 255] = [0, 0, 255]
cv2.imwrite('temp.jpg', result)
cv2.imwrite('out.jpg', img)