- 本文是 Make Your First GAN With PyTorch 的第 10 章,本书的介绍详见这篇文章。
- 本文是这本书正文的最后一章,后面我会补充完这本书的总结和附录。
前面,我们使用 GAN生成了 MNIST 手写体数字,使用 CelebA 训练后生成了大量的人脸图像。
考虑一下,如果 GAN 能产生多样图像的同时,可以限制生成的图像为训练数据的某一类将会很有用。
- 比如,我们能要求 GAN 来产生数字 3 的不同图像;或者,如果使用的脸部图像训练数据中有情感的标注,就能要求 GAN 仅仅产生 “快乐” 的脸。
能实现生成指定条件的 GAN 网络架构,应该是什么形式的呢?
首先,如果想要训练好的 GAN 的生成器输出给定类别的图像,需要告诉网络我们想要哪种类别,这意味着需要在随机种子之外,需要将所需的 “类别” 信息作为生成器的输入。
更进一步,之前的鉴别器仅用于区分真实图像和生成的图像,现在则需要具备同时区分图像及其类别标签的能力。否则,将不能提供给生成器任何有关图像标签的反馈。
基于上述两点,下图是 条件 GAN(Conditional GAN) 的架构:
- 其中,最关键的变化是生成器和鉴别器在输入图像数据时,都有了类别标签。
- 以之前 全连接 MNIST GAN 的代码为基础,通过适当修改获得这个 GAN。
我们需要通过同时输入图像像素值和类别标签信息来更新鉴别器。
所以,我们拓展 forward()
函数,使之能同时将图像张量和类别张量合并并同时输入。其中标签张量是 独热张量(one-hot tensor),在数据集类中也需要备好。
def forward(self, image_tensor, label_tensor):
# 将图像张量和和标签张量进行合并
inputs = torch.cat((image_tensor, label_tensor))
return self.model(inputs)
- 代码中,
torch.cat()
函数是使用一个张量来扩展另一个张量。- 其中,图像张量的长度为 784,而标签张量的长度为 10,故结合体的长度为 794。
由于扩展了输入尺寸,所以需要对神经网络的第一层定义进行拓展,扩展为 784+10 个输入值。
self.model = nn.Sequential(
nn.Linear(784+10, 200),
nn.LeakyReLU(0.02),
nn.LayerNorm(200),
nn.Linear(200, 1),
nn.Sigmoid()
)
除了上述修改,我们对训练函数 train()
进行修改,一方面在调用 forward()
函数时,增加标签的参数。
下面的代码显示了 train()
函数的顶部:
def train(self, inputs, label_tensor):
# 计算网络的输出
outputs = self.forward(inputs, label_tensor)
下面测试鉴别器,通过更新训练循环来将附加的标签张量输入到 train()
函数中:
for label, image_data_tensor, label_tensor in mnist_dataset:
# 真实数据
D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))
# 虚假数据
D.train(generate_random_image(784), generate_random_one_hot(10), torch.FloatTensor([0.0]))
pass
- 上述代码,在随机图像的同时需要一个随机类别标签,因此需要创建一个辅 助函数
generate_random_one_hot()
来生成一个随机独热标签向量。
# 这里的 size 必须是一个整数 integer
def generate_random_one_hot(size):
label_tensor = torch.zeros((size))
random_idx = random.randint(0, size-1)
label_tensor[random_idx] = 1.0
return label_tensor
运行鉴别器代码后,观察鉴别器训练的损失值结果:
- 损失值图表并没有大的改变,而且看起来与原始的鉴别器完全相同。
修改生成器,给它提供种子和标签张量,同样需要先修改 forward()
函数来合并两个张量,便于输入到神经网络中。
def forward(self, seed_tensor, label):
# 合并种子和标签张量
inputs = torch.cat((seed_tensor, label_tensor))
return self.model(inputs)
同时,需要对网络的第一层进行修改,在之前 100 个节点的基础上,增加 10个节点,便于接收 10 个标签值:
self.model = nn.Sequential(
nn.Linear(100+10, 200),
nn.LeakyReLU(0.02),
nn.LayerNorm(200),
nn.Linear(200, 784),
nn.Sigmoid()
)
最后,和鉴别器部分类似,train()
函数也需要接收标签向量:
def train(self, D, inputs, label_tensor, targets):
# 计算网络的输出
g_output = self.forward(inputs, label_tensor)
# 通过 Discriminator 鉴别器
d_output = D.forward(g_output, label_tensor)
- 需要注意的是,在输入到生成器自身的
forward()
函数以及传输生成的图像到鉴别器的forward()
函数时,需要使用相同的标签张量。
更新 GAN 训练的主循环,用于传输标签张量到鉴别器和生成器。下面
仅显示了每个 epoch 循环里的代码:
for label, image_data_tensor, label_tensor in mnist_dataset:
# 使用正确的图像(和标签)来训练鉴别器
D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))
# 产生随机的 1-hot 标签供生成器输入
random_label = generate_random_one_hot(10)
# 使用错误的图像(和标签)训练鉴别器
# 使用 detach(),使得生成器的梯度不被更新
D.train(G.forward(generate_random_seed(100), random_label).detach(), random_label, torch.FloatTensor([0.0])
# 为生成器生成不同的随机 1-hot 标签
random_label = generate_random_one_hot(10)
# 训练生成器
G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))
pass
- 需要注意的是,前面创建了一个
random_label
变量,所以可以在使用生成的图像训练鉴别器时,同时对生成器和鉴别器使用相同的随机标签张量。
新建一个 plot_images()
函数,用于显示特定标签的图像:
def plot_images(self, label):
label_tensor = torch.zeros((10))
label_tensor[label] = 1.0
# 生成 3 列, 2 行的图像
f, axarr = plt.subplots(2, 3, figsize=(16, 8))
for i in range(2):
for j in range(3):
axarr[i, j].imshow(G.forward(generate_random_seed(100), label_tensor).detach().cpu().numpy().reshape(28, 28), interpolation='none', cmap='Blues')
pass
pass
pass
- 该函数使用整数形式的 标签(label),生成一个独热张量,并输入到生成器中。 下面使用 6 个不同的种子来生成 6 个输出图像,显示在网格中。
我们训练了 12 个 epochs 的 GAN,大约花费了 1 小时 30 分钟。
下面是鉴别器的损失值:
- 初看起来,鉴别器的损失值和之前 GAN 的损失值看起来很相似,但更细致的来看,损失值并不是完全接近 0,而事实上是在上升。
- 由于 GAN 的理想损失值并不是 0,所以这个结果令人鼓舞。
下面是生成器的损失值:
- 生成器的损失值也与原始的 GAN 看起来很相似,如果仔细看的话,其平均值看起来也不是 0,这很不错。
上面两个图表明使用附加的标签信息能帮助 GAN 训练。
分析原因,是由于鉴别器具有更多有价值的信息能够帮助它确定图像是否真实,然后将其反馈给生成器。
最后,使用 plot_images(9)
来要求 GAN 生成数字 9 的图像:
- 上图结果说明,这个网络可以工作。这个条件 GAN 确实可以生成数字 9 的图像,甚至比之前的结果更好一些,这些图像也并不完全相同。
分别使用 plot_images(9)
、plot_images(3)
、plot_images(1)
、plot_images(5)
生成图像,如下图:
- 可以看到 GAN 产生了我们想要数字的图像,而且这些图像不完全相同。
能够产生给定类别的多样性强的图像真的很强大。
我们可以预想很多应用,比如可以产生给定表情的脸部图像,或者给定颜色的花朵。但是问题的关键要求是训练数据需要有我们感兴趣类别的标签。
- 本文用到的 MNIST 的条件 GAN 代码在下面链接:
- https://github.com/makeyourownneuralnetwork/gan/blob/master/16_cgan_mnist.ipynb
- 与普通 GAN 不同,条件 GAN(conditional GAN)可以直接生成想要类别的输出;
- 条件 GAN 是通过增加了 类别标签(class label) 的图像和种子来分别训练鉴别器和生成器;
- 条件 GAN 一般比等效的没有标签信息的 GAN 生成的图像质量更高。