数据集应用|如何简单使用PyTorch搭建GAN模型

作者:Ta-Ying Cheng,牛津大学博士研究生,Medium技术博主,多篇文章均被平台官方刊物Towards Data Science收录

以往人们普遍认为生成图像是不可能完成的任务,因为按照传统的机器学习思路,我们根本没有真值(ground truth)可以拿来检验生成的图像是否合格。2014年,Goodfellow等人则提出生成对抗网络(Generative Adversarial Network, GAN),能够让我们完全依靠机器学习来生成极为逼真的图片。GAN的横空出世使得整个人工智能行业都为之震动,计算机视觉和图像生成领域发生了巨变。本文将带大家了解GAN的工作原理,并介绍如何通过PyTorch简单上手GAN。

    01    

GAN的原理

按照传统的方法,模型的预测结果可以直接与已有的真值进行比较。然而,我们却很难定义和衡量到底怎样才算作是“正确的”生成图像。Goodfellow等人则提出了一个有趣的解决办法:我们可以先训练好一个分类工具,来自动区分生成图像和真实图像。这样一来,我们就可以用这个分类工具来训练一个生成网络,直到它能够输出完全以假乱真的图像,连分类工具自己都没有办法评判真假。

数据集应用|如何简单使用PyTorch搭建GAN模型_第1张图片

图 1. GAN的运作流程. 图源作者.

按照这一思路,我们便有了GAN:也就是一个生成器(generator)和一个判别器(discriminator)。生成器负责根据给定的数据集生成图像,判别器则负责区分图像是真是假。GAN的运作流程如图1所示。

损失函数

在GAN的运作流程中,我们可以发现一个明显的矛盾:同时优化生成器和判别器是很困难的。可以想象,这两个模型有着完全相反的目标:生成器想要尽可能伪造出真实的东西,而判别器则必须要识破生成器生成的图像。为了说明这一点,我们设D(x)为判别器的输出,即x是真实图像的概率,并设G(z)为生成器的输出。判别器类似于一种二进制的分类器,所以其目标是使该函数的结果最大化:

这一函数本质上是非负的二元交叉熵损失函数。另一方面,生成器的目标是最小化判别器做出正确判断的机率,因此它的目标是使上述函数的结果最小化。因此,最终的损失函数将会是两个分类器之间的极小极大博弈,表示如下:

理论上来说,博弈的最终结果将是让判别器判断成功的概率收敛到0.5。然而在实践中,极大极小博弈通常会导致网络不收敛,因此仔细调整模型训练的参数非常重要。在训练GAN时,我们尤其要注意学习率等超参数,学习率比较小时能让GAN在输入噪音较多的情况下也能有较为统一的输出。


    02    

计算环境


本文将指导大家通过PyTorch搭建整个程序(包括torchvision)。同时,我们将会使用Matplotlib来让GAN的生成结果可视化。以下代码能够导入上述所有库:

"""
Import necessary libraries to create a generative adversarial network
The code is mainly developed using the PyTorch library
"""
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from model import discriminator, generator
import numpy as np
import matplotlib.pyplot as plt


数据集

数据集对于训练GAN来说非常重要,尤其考虑到我们在GAN中处理的通常是非结构化数据(一般是图片、视频等),任意一class都可以有数据的分布。这种数据分布恰恰是GAN生成输出的基础。为了更好地演示GAN的搭建流程,本文将带大家使用最简单的MNIST数据集,其中含有6万张手写阿拉伯数字的图片。

像MNIST这样高质量的非结构化数据集都可以在格物钛的公开数据集网站上找到。事实上,格物钛Open Datasets平台涵盖了很多优质的公开数据集,同时也可以实现数据集托管及一站式搜索的功能,这对AI开发者来说,是相当实用的社区平台。

数据集应用|如何简单使用PyTorch搭建GAN模型_第2张图片

硬件需求

一般来说,虽然可以使用CPU来训练神经网络,但最佳选择其实是GPU,因为这样可以大幅提升训练速度。我们可以用下面的代码来测试自己的机器能否用GPU来训练:

"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    03    

实现

网络结构

由于数字是非常简单的信息,我们可以将判别器和生成器这两层结构都组建成全连接层(fully connected layers)。我们可以用以下代码在PyTorch中搭建判别器和生成器:

"""
Network Architectures
The following are the discriminator and generator architectures
"""

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return nn.Sigmoid()(x)


class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()

def forward(self, x):
    x = self.activation(self.fc1(x))
    x = self.activation(self.fc2(x))
    x = self.fc3(x)
    x = x.view(-1, 1, 28, 28)
    return nn.Tanh()(x)


训练

在训练GAN的时候,我们需要一边优化判别器,一边改进生成器,因此每次迭代我们都需要同时优化两个互相矛盾的损失函数。对于生成器,我们将输入一些随机噪音,让生成器来根据噪音的微小改变输出的图像:

"""
Network training procedure
Every step both the loss for disciminator and generator is updated
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
"""
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1

        # Training the discriminator
        # Real inputs are actual images of the MNIST dataset
        # Fake inputs are from the generator
        # Real inputs should be classified as 1 and fake as 0
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Training the generator
        # For generator, goal is to make the discriminator believe everything is 1
        noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    if (epoch+1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')


    04    

结果

经过100个训练时期之后,我们就可以对数据集进行可视化处理,直接看到模型从随机噪音生成的数字:

数据集应用|如何简单使用PyTorch搭建GAN模型_第3张图片

图2. GAN的生成结果.图源作者.

我们可以看到,生成的结果和真实的数据非常相像。考虑到我们在这里只是搭建了一个非常简单的模型,实际的应用效果会有非常大的上升空间。


    05    

不仅是有样学样


GAN和以往机器视觉专家提出的想法都不一样,而利用GAN进行的具体场景应用更是让许多人赞叹深度网络的无限潜力。下面我们来看一下两个最为出名的GAN延申应用。

CycleGAN

朱俊彦等人2017年发表的CycleGAN能够在没有配对图片的情况下将一张图片从X域直接转换到Y域,比如把马变成斑马、将热夏变成隆冬、把莫奈的画变成梵高的画等等。这些看似天方夜谭的转换CycleGAN都能轻松做到,并且结果非常准确。

数据集应用|如何简单使用PyTorch搭建GAN模型_第4张图片

图3. 朱俊彦等人提供的CycleGAN生成案例. 图源:Github page.

GauGAN

‍英伟达则通过GAN让人们能够只需要寥寥数笔勾勒出自己的想法,便能得到一张极为逼真的真实场景图片。虽然这种应用需要的计算成本极为高昂,但是GauGAN凭借它的转换能力探索出了前所未有的研究和应用领域。‍

数据集应用|如何简单使用PyTorch搭建GAN模型_第5张图片

图3. GauGAN的生成样例. 左侧是用户输入的简笔画, 右侧是模型生成的图片.图源作者.

    06    

结语


相信看到这里,你已经知道了GAN的大致工作原理,并且能够自己动手简单搭建一个GAN了。

数据集应用|如何简单使用PyTorch搭建GAN模型_第6张图片

登录Open Datasets 免费获取数据集

数据集应用|如何简单使用PyTorch搭建GAN模型_第7张图片

关注公众号 了解格物钛更多资讯

数据集应用|如何简单使用PyTorch搭建GAN模型_第8张图片

数据集应用|如何简单使用PyTorch搭建GAN模型_第9张图片点击阅读原文,查看格物钛更多公开数据集

你可能感兴趣的:(机器学习,人工智能,深度学习,神经网络,数据挖掘)