Table of Contents
Introduction
CycleGAN From Scratch
Loading the Dataset
Building the Discriminator
Building the Residual Block
Building the Generator
Building the CycleGAN
Training the CycleGAN
Evaluating Performance
Conclusion
In this series of articles, we present a mobile image-to-image translation system based on Cycle-Consistent Adversarial Networks (CycleGAN). We build a CycleGAN that can perform unpaired image-to-image translation, and show you some entertaining yet academically deep examples. We also discuss how such a network, built with TensorFlow and Keras, can be converted to TensorFlow Lite and used as an app on mobile devices.
We assume that you are familiar with the concepts of deep learning, as well as with Jupyter Notebooks and TensorFlow. You are welcome to download the project code.
In the previous article of this series, we trained and evaluated a CycleGAN that used a U-Net-based generator. In this article, we'll implement a CycleGAN with a residual-based generator.
The original CycleGAN was built with a residual-based generator. Let's implement this type of CycleGAN from scratch. We'll build the network and train it to reduce artifacts in fundus images, using a fundus dataset that contains images both with and without artifacts.
The network translates fundus images with artifacts into artifact-free fundus images, and vice versa, as shown above.
The CycleGAN design will include the following steps:

Loading the dataset
Building the discriminator
Building the residual block
Building the generator
Building the CycleGAN
Training the CycleGAN
Evaluating performance
Before we start loading the data, let's import some of the required libraries and packages.
# the necessary imports
from random import random
from numpy import load
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randint
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from matplotlib import pyplot
Unlike what we did in the previous article, this time we'll use a local machine (rather than Google Colab) to train the CycleGAN. Hence, the fundus dataset should first be downloaded and processed. We'll use a Jupyter Notebook and TensorFlow to build and train this network.
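The loading code below assumes the dataset is organized into four subfolders, one per domain and split. A minimal sketch of the expected layout (the folder names are taken from the code below; the root path is just an example):

Fundus/
    trainA/    (training fundus images with artifacts - domain A)
    trainB/    (training fundus images without artifacts - domain B)
    testA/     (test images for domain A)
    testB/     (test images for domain B)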
from os import listdir
from numpy import asarray
from numpy import vstack
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from numpy import savez_compressed

# load all images in a directory into memory
def load_images(path, size=(256,256)):
    data_list = list()
    # enumerate filenames in directory, assume all are images
    for filename in listdir(path):
        # load and resize the image
        pixels = load_img(path + filename, target_size=size)
        # convert to numpy array
        pixels = img_to_array(pixels)
        # store
        data_list.append(pixels)
    return asarray(data_list)

# dataset path
path = r'C:/Users/abdul/Desktop/ContentLab/P3/Fundus/'
# load dataset A
dataA1 = load_images(path + 'trainA/')
dataA2 = load_images(path + 'testA/')
dataA = vstack((dataA1, dataA2))
print('Loaded dataA: ', dataA.shape)
# load dataset B
dataB1 = load_images(path + 'trainB/')
dataB2 = load_images(path + 'testB/')
dataB = vstack((dataB1, dataB2))
print('Loaded dataB: ', dataB.shape)
# save as compressed numpy array
filename = 'Artifacts.npz'
savez_compressed(filename, dataA, dataB)
print('Saved dataset: ', filename)
With the data loaded, it's time to write a short script that displays a few of the training images:
# load and plot the prepared dataset
from numpy import load
from matplotlib import pyplot

# load the dataset
data = load('Artifacts.npz')
dataA, dataB = data['arr_0'], data['arr_1']
print('Loaded: ', dataA.shape, dataB.shape)
# plot source images
n_samples = 3
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(dataA[i].astype('uint8'))
# plot target images
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + n_samples + i)
    pyplot.axis('off')
    pyplot.imshow(dataB[i].astype('uint8'))
pyplot.show()
As we've discussed before, the discriminator is a CNN consisting of a number of convolutional layers, along with LeakyReLU and instance normalization layers. It outputs a grid of patch-level real/fake predictions (a PatchGAN), and its loss is weighted by 0.5 so that the discriminator is updated more slowly than the generator.
def define_discriminator(image_shape):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_image = Input(shape=image_shape)
    # C64
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    patch_out = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    # define model
    model = Model(in_image, patch_out)
    # compile model
    model.compile(loss='mse', optimizer=Adam(lr=0.0002, beta_1=0.5), loss_weights=[0.5])
    return model
Once the discriminator is built, we can create a copy of it so that we end up with two identical discriminators: DiscA and DiscB.
image_shape=(256,256,3)
DiscA=define_discriminator(image_shape)
DiscB=define_discriminator(image_shape)
DiscA.summary()
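As a quick sanity check (this print is our addition, not part of the original script): with 256x256 inputs and four stride-2 convolutions, the discriminator's output should be a 16x16 grid of patch scores, which is also the patch size the training loop later reads from d_model_A.output_shape.

# sanity check: four stride-2 convolutions reduce 256 -> 128 -> 64 -> 32 -> 16,
# so the PatchGAN emits one real/fake score per cell of a 16x16 grid
print(DiscA.output_shape)  # expected: (None, 16, 16, 1)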
The next step is to create the residual block for our generator. This block consists of two 2D convolutional layers, each followed by an instance normalization layer; the block's output is then merged with its input.
# generate a resnet block
def resnet_block(n_filters, input_layer):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # first convolutional layer
    g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # second convolutional layer
    g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(g)
    g = InstanceNormalization(axis=-1)(g)
    # concatenate merge channel-wise with input layer
    g = Concatenate()([g, input_layer])
    return g
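Note that this implementation merges the residual path with the block's input channel-wise via Concatenate, which grows the channel count, rather than using the element-wise addition of a classic ResNet block. A minimal additive variant would look like the sketch below (a hypothetical alternative shown for comparison; it is not used elsewhere in this article):

from keras.layers import Add

# hypothetical additive variant of the residual block (classic ResNet style);
# unlike Concatenate, Add keeps the channel count unchanged
def resnet_block_add(n_filters, input_layer):
    init = RandomNormal(stddev=0.02)
    g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(g)
    g = InstanceNormalization(axis=-1)(g)
    # add instead of concatenate
    return Add()([g, input_layer])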
The output of the residual blocks goes through the last part of the generator (the decoder), where the image is upsampled back to its original size. Since the encoder hasn't been defined yet, we'll build a single function that defines both the encoder and decoder parts and connects them to the residual blocks.
# define the generator model
def define_generator(image_shape, n_resnet=9):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    g = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # d128
    g = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # d256
    g = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # R256
    for _ in range(n_resnet):
        g = resnet_block(256, g)
    # u128
    g = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # u64
    g = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    g = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(g)
    g = InstanceNormalization(axis=-1)(g)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model
Now, let's define the two generators, genA and genB.
genA=define_generator(image_shape, 9)
genB=define_generator(image_shape, 9)
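As with the discriminators, a quick shape check can confirm that the generator's decoder restores the input resolution (this print is our addition): the two stride-2 downsampling convolutions are mirrored by two stride-2 transposed convolutions, so the input and output shapes should match.

# sanity check: the generator should map a 256x256x3 image
# to another 256x256x3 image
print(genA.output_shape)  # expected: (None, 256, 256, 3)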
With the generators and discriminators defined, we can now build the complete CycleGAN model and set up its optimizer and other learning parameters. The composite model below trains one generator with a weighted sum of four losses: the adversarial loss (mean squared error, weight 1), the identity loss (L1, weight 5), and the forward and backward cycle-consistency losses (L1, weight 10 each).
# define a composite model
def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
    # ensure the model we're updating is trainable
    g_model_1.trainable = True
    # mark discriminator as not trainable
    d_model.trainable = False
    # mark other generator model as not trainable
    g_model_2.trainable = False
    # discriminator element
    input_gen = Input(shape=image_shape)
    gen1_out = g_model_1(input_gen)
    output_d = d_model(gen1_out)
    # identity element
    input_id = Input(shape=image_shape)
    output_id = g_model_1(input_id)
    # forward cycle
    output_f = g_model_2(gen1_out)
    # backward cycle
    gen2_out = g_model_2(input_id)
    output_b = g_model_1(gen2_out)
    # define model graph
    model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])
    # define optimization algorithm configuration
    opt = Adam(lr=0.0002, beta_1=0.5)
    # compile model with weighting of least squares loss and L1 loss
    model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)
    return model
Now let's define the two composite models (A and B): one translates fundus images with artifacts into artifact-free fundus images (AtoB), and the other translates artifact-free images back into fundus images with artifacts (BtoA).
comb_modelA=define_composite_model(genA,DiscA,genB,image_shape)
comb_modelB=define_composite_model(genB,DiscB,genA,image_shape)
Now that our model is complete, we'll create a training function that defines the training parameters, computes the generator and discriminator losses, and updates the weights during training. This function operates as follows:

Select a batch of real samples from each domain
Generate a batch of fake samples with each generator
Update the fake samples from the image pool
Update each generator through its composite model, using the adversarial, identity, and cycle-consistency losses
Update each discriminator on real and fake samples
Summarize the losses, periodically plot translated samples, and save the generator models
# train the cycleGAN model
def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset):
    # define properties of the training run
    n_epochs, n_batch = 30, 1
    # determine the output square shape of the discriminator
    n_patch = d_model_A.output_shape[1]
    # unpack dataset
    trainA, trainB = dataset
    # prepare image pool for fakes
    poolA, poolB = list(), list()
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)
        X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)
        X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)
        # update fakes from pool
        X_fakeA = update_image_pool(poolA, X_fakeA)
        X_fakeB = update_image_pool(poolB, X_fakeB)
        # update generator B->A via adversarial and cycle loss
        g_loss2, _, _, _, _ = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])
        # update discriminator for A -> [real/fake]
        dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
        dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)
        # update generator A->B via adversarial and cycle loss
        g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])
        # update discriminator for B -> [real/fake]
        dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
        dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)
        # summarize performance
        print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1, dA_loss2, dB_loss1, dB_loss2, g_loss1, g_loss2))
        # evaluate the model performance every so often
        if (i+1) % (bat_per_epo * 1) == 0:
            # plot A->B translation
            summarize_performance(i, g_model_AtoB, trainA, 'AtoB')
            # plot B->A translation
            summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
        if (i+1) % (bat_per_epo * 5) == 0:
            # save the models
            save_models(i, g_model_AtoB, g_model_BtoA)
Below are some helper functions that will be used during training. Note the image pool: following the original CycleGAN training scheme, the discriminators are updated using a buffer of up to 50 previously generated images rather than only the latest fakes, which helps stabilize training.
# load and prepare training images
def load_real_samples(filename):
    # load the dataset
    data = load(filename)
    # unpack arrays
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    # choose random instances
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return X, y

# generate a batch of fake images, returns images and targets
def generate_fake_samples(g_model, dataset, patch_shape):
    # generate fake instance
    X = g_model.predict(dataset)
    # create 'fake' class labels (0)
    y = zeros((len(X), patch_shape, patch_shape, 1))
    return X, y

# update image pool for fake images
def update_image_pool(pool, images, max_size=50):
    selected = list()
    for image in images:
        if len(pool) < max_size:
            # stock the pool
            pool.append(image)
            selected.append(image)
        elif random() < 0.5:
            # use image, but don't add it to the pool
            selected.append(image)
        else:
            # replace an existing image and use replaced image
            ix = randint(0, len(pool))
            selected.append(pool[ix])
            pool[ix] = image
    return asarray(selected)
We've added two more functions: one to save the generator models during training, and one to visualize how well the network reduces artifacts in the fundus images.
def save_models(step, g_model_AtoB, g_model_BtoA):
    # save the first generator model
    filename1 = 'g_model_AtoB_%06d.h5' % (step+1)
    g_model_AtoB.save(filename1)
    # save the second generator model
    filename2 = 'g_model_BtoA_%06d.h5' % (step+1)
    g_model_BtoA.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))

# generate samples and save as a plot
def summarize_performance(step, g_model, trainX, name, n_samples=5):
    # select a sample of input images
    X_in, _ = generate_real_samples(trainX, n_samples, 0)
    # generate translated images
    X_out, _ = generate_fake_samples(g_model, X_in, 0)
    # scale all pixels from [-1,1] to [0,1]
    X_in = (X_in + 1) / 2.0
    X_out = (X_out + 1) / 2.0
    # plot real images
    for i in range(n_samples):
        pyplot.subplot(2, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_in[i])
    # plot translated images
    for i in range(n_samples):
        pyplot.subplot(2, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_out[i])
    # save plot to file
    filename1 = '%s_generated_plot_%06d.png' % (name, (step+1))
    pyplot.savefig(filename1)
    pyplot.close()
# load the prepared image data
dataset = load_real_samples('Artifacts.npz')
# train the models
train(DiscA, DiscB, genA, genB, comb_modelA, comb_modelB, dataset)
Using the above functions, we trained the network for 30 epochs. The results show that our network was able to reduce artifacts in fundus images.
The results of the artifact to artifact-free translation (AtoB) are shown below:
The artifact-free to artifact (BtoA) fundus image translation was also computed; here are some examples.
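To produce such examples after training, a saved generator can be loaded and applied to new images. Below is a minimal inference sketch (our addition; the checkpoint name follows the pattern used by save_models(), but the exact step number is an example and depends on your run):

# load a saved generator and translate one domain-A image (sketch;
# substitute the .h5 file actually produced by your own training run)
from keras.models import load_model

cust = {'InstanceNormalization': InstanceNormalization}
model_AtoB = load_model('g_model_AtoB_035460.h5', custom_objects=cust)
# dataset[0] holds the domain-A images, already scaled to [-1,1]
X_in, _ = generate_real_samples(dataset[0], 1, 0)
X_out = model_AtoB.predict(X_in)
# rescale from [-1,1] to [0,1] for display
pyplot.imshow((X_out[0] + 1) / 2.0)
pyplot.axis('off')
pyplot.show()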
As AI pioneer Yann LeCun said of GANs, they are "the most interesting idea in the last 10 years of deep learning." We hope that, through this series, we've helped you understand why GANs are such an interesting idea. We know you may find the concepts presented in this series a bit heavy and ambiguous, but that's perfectly fine: CycleGAN is quite difficult to grasp in one reading, and you may want to go through this series more than once before it all sinks in.