这篇文章理解自知乎上两篇文章:
- 带你理解CycleGAN,并用TensorFlow轻松实现
- 可能是近期最好玩的深度学习模型:CycleGAN的原理与实验详解
GAN 补充
深度生成模型的分类树如下:
可以根据极大似然原理学习的深度生成模型,根据如何表示或预估概率,可以分为显式密度模型和隐式密度模型。显式密度模型可以构建一个明确的密度模型,p(x;θ),因此可以求得使可能性最大的 θ 值。显式密度模型又分为易解决的和不易解决的(需要使用近似法求最大化可能性的 θ)。对于隐式密度模型,则没有明确表示数据空间的概率分布,相反,该模型提供了一些与该概率分布间接相互作用的方式——生成样本,即定义一种在没有任何输入的情况下,通过随机转换现有样本,以便获取另一个服从同一分布的样本的方法。GAN 即属于隐式密度模型,它直接从模型表示的分布中采样,而非使用马尔可夫链。
GAN 核心原理的数学描述为:
简单分析一下这个公式:
- 整个式子由两项构成。x 表示真实图片,z 表示输入 G 网络的噪声,而 G(z) 表示 G 网络生成的图片。
- D(x) 表示 D 网络判断真实图片是否真实的概率(因为x就是真实的,所以对于 D 来说,这个值越接近1越好)。而 D(G(z)) 是 D 网络判断 G 生成的图片的是否真实的概率。
- G 的目的:G 希望自己生成的图片“越接近真实越好”。也就是说,G 希望 D(G(z)) 尽可能得大,这时 V(D, G) 会变小。因此式子对于 G 来说是求最小(min_G)。
- D的目的:D 的能力越强,D(x) 应该越大,D(G(x)) 应该越小,这时 V(D,G) 会变大。因此式子对于 D 来说是求最大(max_D)。
用随机梯度下降法训练 D 和 G 的算法为:
第一步训练 D,D 是希望 V(G, D) 越大越好,所以是加上梯度(ascending)。第二步训练 G 时,V(G, D) 越小越好,所以是减去梯度(descending)。整个训练过程交替进行。
CycleGAN 原理
CycleGAN的原理可以概述为:将一类图片转换成另一类图片。也就是说,现在有两个样本空间,X 和 Y,我们希望把 X 空间中的样本转换成 Y 空间中的样本。因此,实际的目标就是学习从 X 到 Y 的映射(设这个映射为 F),F 就对应着 GAN 中的生成器,F 可以将 X 中的图片 x 转换为 Y 中的图片 F(x)。对于生成的图片,我们还需要 GAN 中的判别器来判别它是否为真实图片,由此构成对抗生成网络。设这个判别器为 DY。这样的话,根据这里的生成器和判别器,我们就可以构造一个 GAN 损失,表达式为:
这个损失实际上和原始的 GAN 损失是一模一样的,但单纯的使用这一个损失是无法进行训练的。原因在于,映射 F 完全可以将所有 x 都映射为 Y 空间中的同一张图片,使损失无效化。对此,作者又提出了所谓的循环一致性损失(cycle consistency loss)。再假设一个映射 G,它可以将 Y 空间中的图片 y 转换为 X 中的图片 G(y)。CycleGAN 同时学习 F 和 G 两个映射,并要求 F(G(y)) ≈ y,以及 G(F(x)) ≈ x。也就是说,将 X 的图片转换到 Y 空间后,应该还可以转换回来。这样就杜绝模型把所有 X 的图片都转换为 Y 空间中的同一张图片了。根据 F(G(y)) ≈ y 和 G(F(x)) ≈ x,循环一致性损失就定义为:
同时,为 G 也引入一个判别器 DX,由此可以同样定义一个 GAN 的损失 LGAN(G,DX,X,Y)。最终的损失就由三部分组成:
CycleGAN 的结构示意图如下:
从上图可以了解 CycleGAN 的运作过程:两个输入被传递到对应的鉴别器(一个是对应于该域的原始图像,另一个是通过生成器产生的图像),并且鉴别器的任务是区分它们,识别出生成器输出的生成图像,并拒绝此生成图像。生成器想要确保这些图像被鉴别器接受,所以它将尝试生成与 DB 类中原始图像非常接近的新图像。事实上,在生成器分布与所需分布相同时,生成器和鉴别器之间实现了纳什均衡(Nash equilibrium)。
CycleGAN 的灵活性在于不需要提供从源域到目标域的配对转换例子就可以训练。比如,我们希望训练一个将白天的照片转换为夜晚的模型。如果使用pix2pix模型,那么我们必须在搜集大量地点在白天和夜晚的两张对应图片,而使用CycleGAN只需同时搜集白天的图片和夜晚的图片,不必满足对应关系。因此CycleGAN的用途要比pix2pix更广泛,利用CycleGAN就可以做出更多有趣的应用。
CycleGAN 实现
一、构建生成器
生成器的结构如下:
生成器由三部分组成:编码器、转换器、解码器。
编码
第一步是利用卷积网络从输入图像中提取特征。整个编码过程,将 DA 域中一个尺寸为 [256,256,3] 的图像,输入到设计的编码器中,获得了尺寸为 [64,64,256] 的输出 OAenc。
转换
这些网络层的作用是组合图像的不同相近特征,然后基于这些特征,确定如何将图像的特征向量 OAenc 从 DA 域转换为 DB 域的特征向量。因此,作者使用了 6 层 Resnet 模块。OBenc 表示该层的最终输出,尺寸为 [64,64,256],这可以看作是 DB 域中图像的特征向量。
一个 Resnet 模块是一个由两个卷积层组成的神经网络层,其中部分输入数据直接添加到输出。这样做是为了确保先前网络层的输入数据信息直接作用于后面的网络层,使得相应输出与原始输入的偏差缩小,否则原始图像的特征将不会保留在输出中且输出结果会偏离目标轮廓。这个任务的一个主要目标是保留原始图像的特征,如目标的大小和形状,因此残差网络非常适合完成这些转换。Resnet 模块的结构如下所示:
解码
解码过程与编码方式完全相反,从特征向量中还原出低级特征,这是利用了反卷积层(deconvolution)来完成的。最后,我们将这些低级特征转换得到一张在DB域中的图像,得到一个大小为 [256,256,3] 的生成图像 genB。
二、构建鉴别器
鉴别器将一张图像作为输入,并尝试预测其为原始图像或是生成器的输出图像。鉴别器的结构如下所示:
鉴别器本身就属于卷积网络,需要从图像中提取特征;然后是确定这些特征是否属于该特定类别,使用一个产生一维输出的卷积层来完成这个任务。
至此,已经完成该模型的两个主要组成部分,即生成器和鉴别器。由于要使这个模型可以从 A→B 和 B→A 两个方向工作,所以设置了两个生成器,即生成器 A→B 和生成器 B→A,以及两个鉴别器,即鉴别器 A 和鉴别器 B。
三、建立模型
在定义损失函数前,先定义基础输入变量,来构建模型:
input_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_A")
input_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_B")
同时定义模型如下:
gen_B = build_generator(input_A, name="generator_AtoB")
gen_A = build_generator(input_B, name="generator_BtoA")
dec_A = build_discriminator(input_A, name="discriminator_A")
dec_B = build_discriminator(input_B, name="discriminator_B")
dec_gen_A = build_discriminator(gen_A, "discriminator_A")
dec_gen_B = build_discriminator(gen_B, "discriminator_B")
cyc_A = build_generator(gen_B, "generator_BtoA")
cyc_B = build_generator(gen_A, "generator_AtoB")
gen 表示使用相应的生成器后生成的图像,dec 表示在将相应输入传递到鉴别器后做出的判断。因此:
- gen_A 是生成器 B2A 根据真 B 生成的假 A,
gen_B 是生成器 A2B 根据真 A 生成的假 B; - dec_A 是鉴别器 A 对真 A 的鉴别结果,
dec_B 是鉴别器 B 对真 B 的鉴别结果; - dec_gen_A 是鉴别器 A 对 gen_A 的鉴别结果,
dec_gen_B 是鉴别器 B 对 gen_B 的鉴别结果; - cyc_A 是生成器 B2A 根据 gen_B 生成的假 A,
cyc_B 是生成器 A2B 根据 gen_A 生成的假 B.
四、损失函数
现在我们有两个生成器和两个鉴别器。我们要按照实际目的来设计损失函数。损失函数应该包括如下四个部分:
- 鉴别器必须允许所有相应类别的原始图像,即对应输出置 1;
- 鉴别器必须拒绝所有想要愚弄过关的生成图像,即对应输出置 0;
- 生成器必须使鉴别器允许通过所有的生成图像,来实现愚弄操作;
- 所生成的图像必须保留有原始图像的特性,所以如果我们使用生成器 GeneratorA→B 生成一张假图像,那么要能够使用另一个生成器 GeneratorB→A 来努力恢复成原始图像。此过程必须满足循环一致性。
鉴别器损失
通过训练鉴别器 A,使其对真 A 的鉴别输出接近于1,鉴别器 B 也是如此。因此,鉴别器 A 的训练目标为最小化 (DiscriminatorA(a)−1)2 的值,鉴别器 B 也是如此。
另外,由于鉴别器应该能够区分生成图像和原始图像,所以在处理生成图像时期望输出为 0,即鉴别器 A 要最小化 (DiscriminatorA(GeneratorB→A(b)))2 的值。
d_loss_A_1 = tf.reduce_mean(tf.squared_difference(dec_A,1))
d_loss_B_1 = tf.reduce_mean(tf.squared_difference(dec_B,1))
d_loss_A_2 = tf.reduce_mean(tf.square(dec_gen_A))
d_loss_B_2 = tf.reduce_mean(tf.square(dec_gen_B))
d_loss_A = (d_loss_A_1 + d_loss_A_2) / 2
d_loss_B = (d_loss_B_1 + d_loss_B_2) / 2
生成器损失
最终生成器应该使得鉴别器对生成图像的输出值尽可能接近 1。故生成器想要最小化 (DiscriminatorB(GeneratorA→B(a))−1)2。对应代码为:
g_loss_A_1 = tf.reduce_mean(tf.squared_difference(dec_gen_B,1))
g_loss_B_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A,1))
循环损失
最后一个重要参数为循环丢失(cyclic loss),能判断用另一个生成器得到的生成图像与原始图像的差别。因此原始图像和循环图像之间的差异应该尽可能小:
cyc_loss = tf.reduce_mean(tf.abs(input_A - cyc_A)) + tf.reduce_mean(tf.abs(input_B - cyc_B))
所以完整的生成器损失为:
g_loss_A = g_loss_A_1 + 10 * cyc_loss
g_loss_B = g_loss_B_1 + 10 * cyc_loss
cyc_loss 的乘法因子设置为 10,说明循环损失比鉴别损失更重要。
五、训练模型
定义好损失函数,接下来只需要训练模型来最小化损失函数。
d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)
训练过程如下:
for epoch in range(0,100):
# Define the learning rate schedule. The learning rate is kept
# constant upto 100 epochs and then slowly decayed
if(epoch < 100) :
curr_lr = 0.0002
else:
curr_lr = 0.0002 - 0.0002*(epoch-100)/100
# Running the training loop for all batches
for ptr in range(0,num_images):
# Train generator G_A->B
_, gen_B_temp = sess.run([g_A_trainer, gen_B],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
# We need gen_B_temp because to calculate the error in training D_B
_ = sess.run([d_B_trainer],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
# Same for G_B->A and D_A as follow
_, gen_A_temp = sess.run([g_B_trainer, gen_A],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
_ = sess.run([d_A_trainer],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
在训练函数中可以看到,在训练时需要不断调用不同鉴别器和生成器。为了训练模型,需要输入训练图像和选择优化器的学习率。由于 batch_size 设置为1,所以 num_batches 等于 num_images。
我们已经完成了模型构建,下面是模型中一些默认超参数。
生成图像库
计算每个生成图像的鉴别器损失是不可能的,因为会耗费大量的计算资源。为了加快训练,我们存储了之前每个域的所有生成图像,并且每次仅使用一张图像来计算误差。首先,逐个填充图像库使其完整,然后随机将某个库中的图像替换为最新的生成图像,并使用这个替换图像来作为该步的训练。