训练代码:
'''
2021-4-25, edit by wyf, seg-train
python3.7, tensorflow2.0.0
'''
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
# Enable on-demand GPU memory allocation instead of reserving all memory up
# front. TensorFlow requires the memory-growth setting to be identical on
# every GPU, so it is applied to each visible device, not only the first.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)
if not physical_devices:
    # A missing GPU is not fatal: TensorFlow falls back to the CPU.
    print('No GPU detected; running on CPU.')
# Dataset paths.
# glob returns matches in arbitrary order, so both lists are sorted to keep
# image i paired with annotation i (assumes images and trimaps share base
# names -- verify against the dataset layout).
images = sorted(glob.glob('C:/Users/wyf1998/PycharmProjects/tf20/datasets/location_and_seg/images/*.jpg'))
anno = sorted(glob.glob('C:/Users/wyf1998/PycharmProjects/tf20/datasets/location_and_seg/annotations/trimaps/*.png'))
if len(images) != len(anno):
    raise ValueError(
        'Found {} images but {} annotations; they must match one-to-one'.format(
            len(images), len(anno)))

# Shuffle images and annotations with the same fixed-seed permutation.
np.random.seed(2021)
index = np.random.permutation(len(images))
images = np.array(images)[index]
anno = np.array(anno)[index]

# Build the dataset of (image path, annotation path) pairs and split it:
# the first 20% of the shuffled pairs form the test set, the rest the
# training set.
dataset = tf.data.Dataset.from_tensor_slices((images, anno))
test_count = int(len(images)*0.2)
train_count = len(images) - test_count
dataset_train = dataset.skip(test_count)
dataset_test = dataset.take(test_count)
# 相关函数
@tf.function
def read_jpg(path):
    """Read the file at ``path`` and decode it as a 3-channel JPEG image."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_jpeg(raw_bytes, channels=3)
@tf.function
def read_png(path):
    """Read the file at ``path`` and decode it as a 1-channel PNG image."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_png(raw_bytes, channels=1)
@tf.function
def nomal_img(input_images, input_anno):
    """Normalise an image/annotation pair.

    The image is cast to float32 and scaled with ``x / 127.5 - 1``; the
    annotation values are shifted down by one so that pixel classes start
    at 0.

    Returns:
        Tuple ``(scaled_image, shifted_annotation)``.
    """
    scaled_images = tf.cast(input_images, tf.float32) / 127.5 - 1
    shifted_anno = input_anno - 1  # make pixel class ids start from 0
    return scaled_images, shifted_anno
def load_images(input_images_path, input_anno_path):
    """Load one (image, annotation) pair, resize both to 224x224 and normalise.

    Args:
        input_images_path: Path of the JPEG image.
        input_anno_path: Path of the PNG annotation mask.

    Returns:
        Tuple ``(image, annotation)`` as returned by ``nomal_img``.
    """
    input_image = read_jpg(input_images_path)
    input_anno = read_png(input_anno_path)
    input_image = tf.image.resize(input_image, (224, 224))
    # Class ids must be resized with nearest-neighbour sampling: the default
    # bilinear interpolation blends neighbouring ids into fractional values
    # that do not correspond to any class.
    input_anno = tf.image.resize(
        input_anno, (224, 224),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # Nearest-neighbour resize keeps the decoded integer dtype; cast to
    # float32 (the dtype tf.image.resize returns for its other methods) so
    # the label shift in nomal_img cannot wrap around on unsigned values.
    input_anno = tf.cast(input_anno, tf.float32)
    input_image, input_anno = nomal_img(input_image, input_anno)
    return input_image, input_anno
# Decode and preprocess every path pair, then batch. Only the training set
# is shuffled (with a 100-element buffer).
BATCH_SIZE = 8
dataset_train = (
    dataset_train
    .map(load_images)
    .shuffle(buffer_size=100)
    .batch(batch_size=BATCH_SIZE)
)
dataset_test = dataset_test.map(load_images).batch(batch_size=BATCH_SIZE)
print(dataset_train)

# Visualisation (disabled): show the first image of a batch next to its mask.
# for img, anno in dataset_train.take(1):
#     plt.subplot(1, 2, 1)
#     plt.imshow(tf.keras.preprocessing.image.array_to_img(img[0]))
#     plt.subplot(1, 2, 2)
#     plt.imshow(tf.keras.preprocessing.image.array_to_img(anno[0]))
#     plt.show()
# VGG16 convolutional base (ImageNet weights, no classifier head) used as the
# encoder of the segmentation network.
conv_base = tf.keras.applications.VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)
sub_model = tf.keras.models.Model(
    inputs=conv_base.input,
    outputs=conv_base.get_layer('block5_conv3').output,
)

# Expose the intermediate feature maps that the decoder adds back in as skip
# connections, plus the final pooled output.
layer_names = ['block5_conv3', 'block4_conv3', 'block3_conv3', 'block5_pool']
layers_output = [conv_base.get_layer(name).output for name in layer_names]
multi_out_model = tf.keras.models.Model(inputs=conv_base.input, outputs=layers_output)
multi_out_model.trainable = False  # encoder weights are frozen

inputs = tf.keras.layers.Input(shape=(224, 224, 3))
out_block5_conv3, out_block4_conv3, out_block3_conv3, out = multi_out_model(inputs)

# Options shared by the decoder layers. Every transposed convolution uses
# stride 2 with 'same' padding, i.e. it doubles the spatial resolution.
_deconv_opts = dict(strides=(2, 2), padding='same', activation='relu')
_conv_opts = dict(padding='same', activation='relu')

# Decoder stage 1: upsample the pooled output and merge with block5_conv3.
x1 = tf.keras.layers.Conv2DTranspose(512, kernel_size=(3, 3), **_deconv_opts)(out)
x1 = tf.keras.layers.Conv2D(512, 3, **_conv_opts)(x1)
x2 = tf.add(x1, out_block5_conv3)

# Decoder stage 2: upsample and merge with block4_conv3.
x2 = tf.keras.layers.Conv2DTranspose(512, kernel_size=(3, 3), **_deconv_opts)(x2)
x2 = tf.keras.layers.Conv2D(512, 3, **_conv_opts)(x2)
x3 = tf.add(x2, out_block4_conv3)

# Decoder stage 3: upsample and merge with block3_conv3.
x3 = tf.keras.layers.Conv2DTranspose(256, kernel_size=(3, 3), **_deconv_opts)(x3)
x3 = tf.keras.layers.Conv2D(256, 3, **_conv_opts)(x3)
x4 = tf.add(x3, out_block3_conv3)

# Two further upsampling steps; the last one emits a 3-way softmax per pixel.
x5 = tf.keras.layers.Conv2DTranspose(256, 3, **_deconv_opts)(x4)
pred = tf.keras.layers.Conv2DTranspose(
    3,
    kernel_size=3,
    strides=(2, 2),
    padding='same',
    activation='softmax',
)(x5)

model = tf.keras.models.Model(inputs=inputs, outputs=pred)
model.summary()
# Optimiser and loss. The model's last layer already applies softmax, so the
# loss consumes probabilities (from_logits=False).
optimizer = tf.keras.optimizers.Adam()
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

# Running metrics for the training and evaluation passes.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')
def train_step(images, labels, model):
    """Run one optimisation step on a batch and update the training metrics.

    Args:
        images: Batch of input images.
        labels: Batch of per-pixel class ids matching ``images``.
        model: Model whose trainable variables are updated.
    """
    with tf.GradientTape() as tape:
        batch_pred = model(images)
        batch_loss = loss_func(labels, batch_pred)
    variables = model.trainable_variables
    gradients = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    # Accumulate into the module-level running metrics.
    train_loss(batch_loss)
    train_acc(labels, batch_pred)
def test_step(images, labels, model):
    """Evaluate one batch and update the evaluation metrics (no weight update).

    Args:
        images: Batch of input images.
        labels: Batch of per-pixel class ids matching ``images``.
        model: Model to evaluate.
    """
    batch_pred = model(images)
    batch_loss = loss_func(labels, batch_pred)
    # Accumulate into the module-level running metrics.
    test_loss(batch_loss)
    test_acc(labels, batch_pred)
def train(data_train, data_test, epoch_nums):
    """Train the module-level ``model`` and evaluate it after every epoch.

    Progress is printed after every training and evaluation batch, followed
    by a one-line summary per epoch. All running metrics are reset at the end
    of each epoch.

    Args:
        data_train: Iterable of ``(image, label)`` training batches.
        data_test: Iterable of ``(image, label)`` evaluation batches.
        epoch_nums: Number of epochs to run.
    """
    steps_per_epoch = train_count // BATCH_SIZE
    # Epochs are numbered 1..epoch_nums inclusive so that exactly
    # ``epoch_nums`` epochs are executed.
    for epoch in range(1, epoch_nums + 1):
        print('step_per_epoch:{}'.format(steps_per_epoch))
        for (batch, (image, label)) in enumerate(data_train):
            train_step(image, label, model)
            print('epoch{}/{}, train_loss={:.3f}, train_acc={:.3f}'.format(epoch,
                                                                          batch,
                                                                          train_loss.result(),
                                                                          train_acc.result()))
        for image_, label_ in data_test:
            test_step(image_, label_, model)
            print('epoch{}--val-----, test_loss={:.3f}, test_acc={:.3f}'.format(epoch,
                                                                              test_loss.result(),
                                                                              test_acc.result()))
        print('Epoch {}, loss= {}, acc= {}, test_loss= {}, test_acc= {}'.format(epoch,
                                                                            train_loss.result(),
                                                                            train_acc.result(),
                                                                            test_loss.result(),
                                                                            test_acc.result()))
        # Start the next epoch with fresh metric accumulators.
        train_loss.reset_states()
        train_acc.reset_states()
        test_loss.reset_states()
        test_acc.reset_states()
if __name__ == '__main__':
    # Run training, then save the trained model to an HDF5 file.
    train(dataset_train, dataset_test, epoch_nums=50)
    model.save('my_model.h5')
测试代码:
'''
2021-4-25, edit by wyf,seg-test
python3.7, tensorflow2.0.0
'''
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
# Enable on-demand GPU memory allocation instead of reserving all memory up
# front. TensorFlow requires the memory-growth setting to be identical on
# every GPU, so it is applied to each visible device, not only the first.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)
if not physical_devices:
    # A missing GPU is not fatal: inference also runs on the CPU.
    print('No GPU detected; running on CPU.')
# glob returns matches in arbitrary order; sort them so that the fixed index
# used further down always selects the same file.
image_path = sorted(glob.glob('C:/Users/wyf1998/PycharmProjects/tf20/datasets/location_and_seg/images/*.jpg'))
# Restore the trained segmentation model from the HDF5 file and print its
# architecture.
model = tf.keras.models.load_model('my_model.h5')
model.summary()
def read_jpg(path):
    """Read the file at ``path`` and decode it as a 3-channel JPEG image."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_jpeg(raw_bytes, channels=3)
def nomal_img(input_images):
    """Cast ``input_images`` to float32 and divide by 255."""
    as_float = tf.cast(input_images, tf.float32)
    return as_float / 255
def load_images(input_images_path):
    """Load a JPEG image, resize it to 224x224 and normalise it.

    Args:
        input_images_path: Path of the JPEG image.

    Returns:
        The image after ``tf.image.resize`` and ``nomal_img``.
    """
    decoded = read_jpg(input_images_path)
    resized = tf.image.resize(decoded, (224, 224))
    return nomal_img(resized)
# Load one sample image. nomal_img leaves it in [0, 1], which is the range
# plt.imshow expects for floating-point RGB data.
img = load_images(image_path[563])
plt.subplot(1, 2, 1)
plt.imshow(img)

# The training script feeds the network images scaled with x / 127.5 - 1,
# i.e. in [-1, 1]. Map the [0, 1] image to that range (and add the batch
# axis) so inference sees the same input distribution as training.
img = tf.expand_dims(img, 0) * 2.0 - 1.0
pred = model.predict(img)
print(pred)

# Turn per-pixel class probabilities into a class-id map and drop the batch
# axis before plotting.
pred = tf.argmax(pred, axis=-1)
pred = pred[..., tf.newaxis]
pred = tf.squeeze(pred, axis=0)
plt.subplot(1, 2, 2)
plt.imshow(pred)
plt.show()