条件生成对抗网络(Conditional GAN)(Make Your First GAN With PyTorch 第10章)

  • 本文是 Make Your First GAN With PyTorch 的第 10 章,本书的介绍详见这篇文章。
  • 本文是这本书正文的最后一章,后面我会补充完这本书的总结和附录。

本文目录

  • 1. 条件 GAN 的架构
  • 2. 鉴别器
  • 3. 生成器
  • 4. 训练循环
  • 5. 显示图像
  • 6. 条件 GAN 的结果
  • 7. 全文总结


前面,我们使用 GAN生成了 MNIST 手写体数字,使用 CelebA 训练后生成了大量的人脸图像。

考虑一下,如果 GAN 能产生多样图像的同时,可以限制生成的图像为训练数据的某一类将会很有用。

  • 比如,我们能要求 GAN 来产生数字 3 的不同图像;或者,如果使用的脸部图像训练数据中有情感的标注,就能要求 GAN 仅仅产生 “快乐” 的脸。

1. 条件 GAN 的架构

能实现生成指定条件的 GAN 网络架构,应该是什么形式的呢?

  • 首先,如果想要训练好的 GAN 的生成器输出给定类别的图像,需要告诉网络我们想要哪种类别,这意味着需要在随机种子之外,需要将所需的 “类别” 信息作为生成器的输入。

  • 更进一步,之前的鉴别器仅用于区分真实图像和生成的图像,现在则需要具备同时区分图像及其类别标签的能力。否则,将不能提供给生成器任何有关图像标签的反馈。

基于上述两点,下图是 条件 GAN(Conditional GAN) 的架构:

条件生成对抗网络(Conditional GAN)(Make Your First GAN With PyTorch 第10章)_第1张图片

  • 其中,最关键的变化是生成器和鉴别器在输入图像数据时,都有了类别标签。

2. 鉴别器

  • 以之前 全连接 MNIST GAN 的代码为基础,通过适当修改获得这个 GAN

我们需要通过同时输入图像像素值和类别标签信息来更新鉴别器。

所以,我们拓展 forward() 函数,使之能同时将图像张量和类别张量合并并同时输入。其中标签张量是 独热张量(one­-hot tensor),在数据集类中也需要备好。

def forward(self, image_tensor, label_tensor):
    # 将图像张量和和标签张量进行合并
    inputs = torch.cat((image_tensor, label_tensor))
    return self.model(inputs)
  • 代码中, torch.cat() 函数是使用一个张量来扩展另一个张量。
  • 其中,图像张量的长度为 784,而标签张量的长度为 10,故结合体的长度为 794

由于扩展了输入尺寸,所以需要对神经网络的第一层定义进行拓展,扩展为 784+10 个输入值。

self.model = nn.Sequential(
    nn.Linear(784+10, 200),
    nn.LeakyReLU(0.02),
    
    nn.LayerNorm(200),
    
    nn.Linear(200, 1),
    nn.Sigmoid()
)

除了上述修改,我们对训练函数 train() 进行修改,一方面在调用 forward() 函数时,增加标签的参数。

下面的代码显示了 train() 函数的顶部:

def train(self, inputs, label_tensor):
    # 计算网络的输出
    outputs = self.forward(inputs, label_tensor)

下面测试鉴别器,通过更新训练循环来将附加的标签张量输入到 train() 函数中:

for label, image_data_tensor, label_tensor in mnist_dataset:
    # 真实数据
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))
    # 虚假数据
    D.train(generate_random_image(784), generate_random_one_hot(10), torch.FloatTensor([0.0]))
    pass 
  • 上述代码,在随机图像的同时需要一个随机类别标签,因此需要创建一个辅 助函数 generate_random_one_hot() 来生成一个随机独热标签向量。
# 这里的 size 必须是一个整数 integer
def generate_random_one_hot(size):
    label_tensor = torch.zeros((size))
    random_idx = random.randint(0, size-1)
    label_tensor[random_idx] = 1.0
    return label_tensor

运行鉴别器代码后,观察鉴别器训练的损失值结果:

条件生成对抗网络(Conditional GAN)(Make Your First GAN With PyTorch 第10章)_第2张图片

  • 损失值图表并没有大的改变,而且看起来与原始的鉴别器完全相同。

3. 生成器

修改生成器,给它提供种子和标签张量,同样需要先修改 forward() 函数来合并两个张量,便于输入到神经网络中。

def forward(self, seed_tensor, label):
    # 合并种子和标签张量
    inputs = torch.cat((seed_tensor, label_tensor))
    return self.model(inputs)

同时,需要对网络的第一层进行修改,在之前 100 个节点的基础上,增加 10个节点,便于接收 10 个标签值:

self.model = nn.Sequential(
    nn.Linear(100+10, 200),
    nn.LeakyReLU(0.02),
    
    nn.LayerNorm(200),
    
    nn.Linear(200, 784),
    nn.Sigmoid()
)

最后,和鉴别器部分类似,train() 函数也需要接收标签向量:

def train(self, D, inputs, label_tensor, targets):
    # 计算网络的输出
    g_output = self.forward(inputs, label_tensor)
    
    # 通过 Discriminator 鉴别器
    d_output = D.forward(g_output, label_tensor)
  • 需要注意的是,在输入到生成器自身的 forward() 函数以及传输生成的图像到鉴别器的 forward() 函数时,需要使用相同的标签张量。

4. 训练循环

更新 GAN 训练的主循环,用于传输标签张量到鉴别器和生成器。下面
仅显示了每个 epoch 循环里的代码:

for label, image_data_tensor, label_tensor in mnist_dataset:
    # 使用正确的图像(和标签)来训练鉴别器
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))
    
    # 产生随机的 1-hot 标签供生成器输入
    random_label = generate_random_one_hot(10)
    
    # 使用错误的图像(和标签)训练鉴别器
    # 使用 detach(),使得生成器的梯度不被更新
    D.train(G.forward(generate_random_seed(100), random_label).detach(), random_label, torch.FloatTensor([0.0])
    
    # 为生成器生成不同的随机 1-hot 标签
    random_label = generate_random_one_hot(10)
    
    # 训练生成器
    G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))
    
    pass
  • 需要注意的是,前面创建了一个 random_label 变量,所以可以在使用生成的图像训练鉴别器时,同时对生成器和鉴别器使用相同的随机标签张量。

5. 显示图像

新建一个 plot_images() 函数,用于显示特定标签的图像:

def plot_images(self, label):
    label_tensor = torch.zeros((10))
    label_tensor[label] = 1.0
    # 生成 3 列, 2 行的图像
    f, axarr = plt.subplots(2, 3, figsize=(16, 8))
    for i in range(2):
        for j in range(3):
            axarr[i, j].imshow(G.forward(generate_random_seed(100), label_tensor).detach().cpu().numpy().reshape(28, 28), interpolation='none', cmap='Blues')
            pass
        pass
    pass
  • 该函数使用整数形式的 标签(label),生成一个独热张量,并输入到生成器中。 下面使用 6 个不同的种子来生成 6 个输出图像,显示在网格中。

6. 条件 GAN 的结果

我们训练了 12 个 epochs 的 GAN,大约花费了 1 小时 30 分钟。

下面是鉴别器的损失值:

条件生成对抗网络(Conditional GAN)(Make Your First GAN With PyTorch 第10章)_第3张图片

  • 初看起来,鉴别器的损失值和之前 GAN 的损失值看起来很相似,但更细致的来看,损失值并不是完全接近 0,而事实上是在上升。
  • 由于 GAN 的理想损失值并不是 0,所以这个结果令人鼓舞。

下面是生成器的损失值:

条件生成对抗网络(Conditional GAN)(Make Your First GAN With PyTorch 第10章)_第4张图片

  • 生成器的损失值也与原始的 GAN 看起来很相似,如果仔细看的话,其平均值看起来也不是 0,这很不错。

上面两个图表明使用附加的标签信息能帮助 GAN 训练。

分析原因,是由于鉴别器具有更多有价值的信息能够帮助它确定图像是否真实,然后将其反馈给生成器。

最后,使用 plot_images(9) 来要求 GAN 生成数字 9 的图像:

条件生成对抗网络(Conditional GAN)(Make Your First GAN With PyTorch 第10章)_第5张图片

  • 上图结果说明,这个网络可以工作。这个条件 GAN 确实可以生成数字 9 的图像,甚至比之前的结果更好一些,这些图像也并不完全相同。

分别使用 plot_images(9)plot_images(3)plot_images(1)plot_images(5) 生成图像,如下图:

条件生成对抗网络(Conditional GAN)(Make Your First GAN With PyTorch 第10章)_第6张图片

  • 可以看到 GAN 产生了我们想要数字的图像,而且这些图像不完全相同。

能够产生给定类别的多样性强的图像真的很强大。

我们可以预想很多应用,比如可以产生给定表情的脸部图像,或者给定颜色的花朵。但是问题的关键要求是训练数据需要有我们感兴趣类别的标签。

  • 本文用到的 MNIST 的条件 GAN 代码在下面链接:
  • https://github.com/makeyourownneuralnetwork/gan/blob/master/16_cgan_mnist.ipynb

7. 全文总结

  • 与普通 GAN 不同,条件 GANconditional GAN)可以直接生成想要类别的输出;
  • 条件 GAN 是通过增加了 类别标签(class label) 的图像和种子来分别训练鉴别器和生成器;
  • 条件 GAN 一般比等效的没有标签信息的 GAN 生成的图像质量更高。

你可能感兴趣的:(Pytorch,Make,First,GAN,With,PyTorch,Python学习笔记,pytorch,生成对抗网络,深度学习)