深度之眼Pytorch框架训练营第四期——生成对抗网络GAN与Pytorch实现

文章目录

      • 生成对抗网络GAN与Pytorch实现
        • 1、生成对抗网络(GAN)是什么?
        • 2、如何训练GAN?
        • 3、 训练DCGAN实现人脸生成
          • (1)网络结构
          • (2)Pytorch实现
        • 4、 GAN的应用

生成对抗网络GAN与Pytorch实现

1、生成对抗网络(GAN)是什么?

所谓的生成对抗网络,就是一种可以生成特定分布数据的神经网络模型

  • GAN网络结构

深度之眼Pytorch框架训练营第四期——生成对抗网络GAN与Pytorch实现_第1张图片

如上图所示, G A N GAN GAN网络结构中,最重要的是两个模块: G G G D D D,输入的数据,通过 G G G生成了一些伪数据 G ( z ) G(z) G(z),然后与真实数据 X X X一同输入到 D D D模块,然后进行判断是真实数据还是伪数据

2、如何训练GAN?

  • 训练目的:
  • 对于D:对真样本输出高概率
  • 对于G:输出使D会给出高概率的数据

从上面的训练目的可以看出,GAN的训练与传统的监督学习训练模型并不相同,传统的监督学习训练模型只能做数据的映射,输入的数据通过模型后得到输出值,随后构造损失函数衡量输出值与真实标签中间的差异,将这个差异值求导并采用梯度下降的方法更新模型中的参数,从而使得模型的输出逼近真实的标签值,在这种监督学习模型中,一个很核心的模块就是损失函数模块,而在GAN的训练模型中并没有损失函数模块,这是GAN训练模型与监督学习模型最大的不同,GAN训练模型输入的是随机数,通过 G G G模块输出了“伪数据”,但这里并不会构造损失函数去比较输出的“伪数据”与真实数据之间的差异,这是毫无意义的,“伪数据”与真实数据之间的差异通过 D D D模型获得,这里的 D D D模块就充当了监督学习中的损失函数模块的角色,得到差异值后,类似的对 G G G中的参数进行更新,从而使得“伪数据”逼近于真实的训练数据,需要注意的是这里的“逼近”并不是数值上的逼近,而是分布上的逼近

  • 训练步骤:
  • 第一步:先训练模块 D D D,使判别器具有判断真假的能力
    • 输入:真实数据加G生成的假数据
    • 输出:二分类概率
  • 第二步:再训练模块 G G G
    • 输入:随机噪声 z z z
    • 输出:分类概率: D ( G ( z ) ) D(G(z)) D(G(z))

深度之眼Pytorch框架训练营第四期——生成对抗网络GAN与Pytorch实现_第2张图片

  • 训练伪代码:

深度之眼Pytorch框架训练营第四期——生成对抗网络GAN与Pytorch实现_第3张图片

3、 训练DCGAN实现人脸生成

(1)网络结构

所谓DCGAN,就是利用卷积网络,实现GAN,即 D D D模块, G G G模块都使用卷积神经网络实现

  • Generator

深度之眼Pytorch框架训练营第四期——生成对抗网络GAN与Pytorch实现_第4张图片

从上图中可以看出,输入的是一个长度为100的张量,但是在Pytorch中的必须理解为一个四维张量——(1,100,1,1),最终得到的是一个 3 × 64 × 64 3\times64\times64 3×64×64的RGB图像

  • Discriminator

深度之眼Pytorch框架训练营第四期——生成对抗网络GAN与Pytorch实现_第5张图片

从上图中可以看出,输入的是一个 3 × 64 × 64 3\times64\times64 3×64×64的RGB图像,输出的是一个长度为2的向量,用于判断正负样本

(2)Pytorch实现

Pytorch需要对Generator和Discriminator进行实现:

class Generator(nn.Module):
    def __init__(self, nz=100, ngf=128, nc=3):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

class Discriminator(nn.Module):
    def __init__(self, nc=3, ndf=128):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

4、 GAN的应用

GAN应用范围非常广,可以参看此链接

你可能感兴趣的:(深度之眼Pytorch框架训练营第四期——生成对抗网络GAN与Pytorch实现)