生成对抗网络

目录

0. Abstract

1. Introduction

2. Relatedwork

3.Experiments

4.Advantages and disadvantages

5.Conclusions and future work(idea)

6. 网络训练源代码


0. Abstract

我们提出了一个新的框架,通过一个对抗的过程来估计生成模型,在此过程中我们同时训练两个模型:一个生成模型G捕获数据分布,和一种判别模型D,它估计样本来自训练数据而不是G的概率。G的训练程序是最大化D犯错的概率,这个框架对应于一个极小极大的双人游戏。在任意函数G和D的空间中,存在唯一解,G可以重现训练数据分布,D处处等于1/2。在G和D由多层感知器定义的情况下,整个系统可以通过反向传播进行训练。在训练或生成样本的过程中,不需要任何马尔科夫链或展开的近似推理网络。通过对生成的样本进行定性和定量评估,实验证明了该框架的潜力。

1. Introduction

深度学习的前景是发现丰富的分层模型,它代表人工智能应用中遇到的各种数据的概率分布,如自然图像、包含语音的音频波形和自然语言语料库中的符号。到目前为止,在深度学习中最显著的成功涉及到判别模型,通常是那些将高维、丰富的感官输入映射到类标签的模型。这些惊人的成功主要是基于反向传播和dropout算法,使用分段线性单元,具有特别良好的梯度。由于在极大似然估计和相关策略中出现的许多难以处理的概率计算的近似性,以及由于难以在生成环境中利用分段线性单元的优点,深度生成模型的影响较小。我们提出了一种新的生成模型估计方法来克服这些困难。

在提出的对抗网框架中,生成模型与对手进行了比较:一个学习确定样本是来自模型分布还是来自数据分布的判别模型。生成模型可以被认为类似于一组伪造者,他们试图制造假币并在不被发现的情况下使用它,而判别模型则类似于警察,试图发现假币,这个游戏的竞争促使两队改进他们的方法,直到仿冒品无法从真品中辨别出来。

该框架可以生成针对多种模型的特定训练算法和优化算法,在这篇文章中,我们探讨了生成模型通过一个多层感知器传递随机噪声来生成样本的特殊情况,而判别模型也是一个多层感知器,我们把这种特殊情况称为对抗网络。在这种情况下,我们可以只使用非常成功的反向传播和dropout算法来训练这两个模型,并且只使用正向传播来训练生成模型的样本,不需要近似推论或马尔科夫链。

2. Relatedwork

有潜在变量的有向图形模型的另一种选择是有潜在变量的无向图形模型,如限制玻尔兹曼机(RBMs),深玻尔兹曼机(DBMs)及其众多变体。这些模型中的相互作用被表示为未归一化势函数的乘积,由随机变量所有状态的全局求和/积分进行归一化。这个数量(配分函数)和它的梯度是棘手的,但最琐碎的情况下,虽然他们可以由马尔可夫链蒙特卡罗(MCMC)方法估计。对于依赖于MCMC的学习算法来说,混合是一个很重要的问题。

深度置信网络(DBNs)[16]是包含一个无向层和多个有向层的混合模型。虽然存在一种快速的分层近似训练准则,但DBNs存在与无向和有向模型相关的计算困难。

也有人提出了不近似或不限制对数似然的替代标准,如分数匹配和噪声对比估计(NCE),这两种方法都要求所学习的概率密度被解析指定为一个归一化常数。请注意,在许多具有多层潜在变量(如DBNs和DBMs)的有趣生成模型中,甚至不可能导出可处理的非规范化概率密度,一些模型,如去噪自动编码器[30]和收缩自动编码器的学习规则非常类似于分数匹配应用于RBMs。在NCE中,与本文一样,使用了判别训练准则来拟合生成模型。然而,生成模型本身用于从固定噪声分布的样本中区分生成的数据,而不是拟合一个单独的判别模型。由于NCE使用一个固定的噪声分布,当模型学习到即使是在观察变量的一个小子集上的一个近似正确的分布之后,学习速度也会显著减慢。

最后,一些技术不涉及明确定义概率分布,而是训练生成机器从期望的分布中抽取样本,这种方法的优点是可以通过反向传播来训练这些机器。近期主要的工作包括生成随机网络(GSN)框架:它扩展了广义去噪自动编码器:两者都可以看作是定义一个参数化的马尔科夫链,即一个人学习机器的参数,执行一个步骤的生成马尔科夫链。与GSNs相比,对抗网的采样不需要马尔科夫链,由于反求网络在生成过程中不需要反馈环,所以它们能够更好地利用分段线性单元,这提高了反向传播的性能,但在使用反馈环时存在无限制激活的问题。通过反向传播训练生成机器的最新例子包括自动编码变分贝叶斯和随机反向传播。

当模型都是多层感知器时,对抗性建模框架最容易应用。为了学习生成器在数据x上的分布pg,我们定义了一个输入噪声变量pz (z), G (z;θg)表示将噪声变量映射到数据空间, G是一个可微函数,表示为一个参数为θg的多层感知器。我们还定义一个多层感知器D (x;θd)输出一个标量,D(x)表示x来自数据集而不是pg的概率。我们训练D最大限度地将正确的标签分配给训练样本和来自G的样本的概率,我们同时训练G,使得 log(1 - D(G(z))) 最小化。

换句话说,D和G玩了一个具有值函数V (G,D)的二人极大极小博弈:

在下一节中,我们将对对抗网进行理论分析,主要说明当G和D具有足够的容量时,训练准则允许恢复数据生成分布,例如在非参数极限下。请参见图1,其中对该方法进行了不太正式的、更具教育性的解释。在实践中,我们必须使用迭代的数值方法来实现游戏。优化完成内环的训练在计算上是禁止的,对于有限的数据集会导致过度拟合。相反,我们在优化D的k个步骤和优化G的一个步骤之间交替进行,只要G变化足够慢,D就会保持在其最优解附近,这种策略类似于SML/PCD:训练从一个学习步骤到下一个学习步骤保持来自马尔可夫链的样本,该过程在算法1中正式给出。

在实际应用中,公式1可能无法为G提供足够的梯度来学习。在学习的早期,当G较差时,D可以很有信心地拒绝样本,因为它们与训练数据明显不同。在这种情况下,log(1 - D(G(z)))饱和,与其训练G去最小化log(1 - D(G(z))不如训练G去最大化logD(G(z))这一目标函数的结果与动态函数相同,但在学习中提供了更强的学习效果。

生成对抗网络_第1张图片

注:图中的黑色虚线表示真实的样本的分布情况,蓝色虚线表示判别器判别概率的分布情况,绿色实线表示生成样本的分布。Z表示噪声,Z到x表示通过生成器之后的分布的映射情况。
我们的目标是使用生成样本分布(绿色实线)去拟合真实的样本分布(黑色虚线),来达到生成以假乱真样本的目的。
可以看到在(a)状态处于最初始的状态的时候,生成器生成的分布和真实分布区别较大,并且判别器判别出样本的概率不是很稳定,因此会先训练判别器来更好地分辨样本。
通过多次训练判别器来达到(b)样本状态,此时判别样本区分得非常显著和良好。然后再对生成器进行训练。
训练生成器之后达到(c)样本状态,此时生成器分布相比之前,逼近了真实样本分布。
经过多次反复训练迭代之后,最终希望能够达到(d)状态,生成样本分布拟合于真实样本分布,并且判别器分辨不出样本是生成的还是真实的(判别概率均为0.5)。也就是说我们这个时候就可以生成出非常真实的样本啦,目的达到。[2]

3.Experiments

包括MNIST, theTorontoFace Database (TFD),和CIFAR-10一系列数据集上训练了对抗网络。生成网络使用rectifier linear and sigmoid两种激活函数,而判别器使用maxout激活。应用dropout训练判别器网络。虽然我们的理论框架允许在生成器的中间层使用dropout和其他噪声,但我们只使用噪声作为生成器网络最底层的输入。

4.Advantages and disadvantages

与以前的建模框架相比,这个新框架有优点也有缺点。缺点主要是没有显式表示的pg (x),在训练时D必须与G同步。它的优点是不需要使用马尔科夫链,只使用backprop来获得梯度,在学习过程中不需要推理,可以将多种函数合并到模型中。

5.Conclusions and future work(idea)

  1. 将c作为G和D的输入,可以得到条件生成模型p(x | c)。
  2. 学习近似推理:可以利用一个辅助网络在给定x时来预测z。这与wake-sleep算法训练的推理网络类似,但具有在生成器网络完成训练后,可以对固定生成器网络进行推理网络训练的优点。
  3. 通过训练一系列共享参数的条件模型,可以近似地对所有条件p(xS | x)进行建模,其中s是x指标的子集。本质上,我们可以使用对抗网来实现确定性MP-DBM[11]的随机扩展。
  4. 半监督学习:当有限的标记数据可用时,鉴别器或推理器的特性可能会降低分类器的性能。
  5. 效率改进:在培训过程中,通过划分更好的方法来协调G和D,或者确定更好的z分布,可以大大加快训练的速度。

6. 网络训练源代码

import torch.nn as nn
from torchvision import transforms
import torch
import torch.optim as op
from torchvision import datasets
from torch.utils.data import DataLoader

batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
# test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
# test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
"生成器的输入是一组噪声"
class Generator(nn.Module):
    def __init__(self, in_features=64, out_features=784):
        """
        :param in_features: 生成器的in_features,一般输入z的维度z_dim,该值可自定义
        :param out_features: 生成器的out_features,需要与真实数据的维度一致
        """
        super().__init__()
        "nn.Tanh() #用于归一化数据"
        self.gen = nn.Sequential(nn.Linear(in_features, 256),
                                 nn.LeakyReLU(0.1),
                                 nn.Linear(256, out_features),
                                 nn.Tanh()
                                 )
    def forward(self, z):
        gz = self.gen(z)
        return gz

"判别器"
class Discriminator(nn.Module):
    def __init__(self, in_features=784):
        """
        :param in_features: 真实数据的维度、同时也是生成的假数据的
        """
        super().__init__()
        "使用非饱和激活函数nn.LeakyReLU(0.1),防止梯度下降"
        "nn.Tanh() 是双曲正切函数,通常用于确保生成的输出处于特定的值范围内,例如在 -1 到 1 之间"
        self.disc = nn.Sequential(nn.Linear(in_features, 128),
                                  nn.LeakyReLU(0.1),
                                  nn.Linear(128, 1),
                                  nn.Sigmoid()
                                  )
    def forward(self, data):
        """
        :param data: 输入的data可以是真实数据时,Disc输出dx。输入的data是gz时,Disc输出dgz
        :return:
        """
        return self.disc(data)    # 输出结果为置信度


z_dim = 64
real_data_dim = 784
lr = 0.1
"判断是否有GPU存在"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"实例化判别器与生成器"
gen = Generator(in_features=z_dim, out_features=real_data_dim)
gen.to(device)
disc = Discriminator(in_features=real_data_dim).to(device)
disc.to(device)
"定义判别器与生成器所使用的优化算法"
op_disc = op.Adam(disc.parameters(), lr=lr, betas=(0.9, 0.999))
op_gen = op.Adam(gen.parameters(), lr=lr, betas=(0.9, 0.999))
"定义损失函数"
criterion = nn.BCELoss(reduction="mean")
if __name__ == "__main__":
    for epoch in range(10):
        for batch_idx, (x, _) in enumerate(data_loader):
            x = x.view(-1, 784).to(device)
            batch_size = x.shape[0]
            # 判别器反向传播==========================================================================
            "------------------------判别器对真实数据的预测概率------------------------"
            dx = disc(x).view(-1)
            "所有真实数据的损失均值"
            loss_real = criterion(dx, torch.ones_like(dx))
            loss_real.backward()
            "判别器对真实数据的预测概率 dx 的平均值,然后使用 .item() 方法将其转换为标量值,并将结果存储在 D_x 变量中"
            D_x = dx.mean().item()
            "------------------------判别器对生成数据的预测概率------------------------"
            noise = torch.randn((batch_size, z_dim)).to(device)
            "将随机噪声 noise 通过生成器模型 gen 生成假数据 gz,这些假数据模拟真实数据的特征"
            gz = gen(noise)
            "使用 gz.detach() 是为了阻止生成数据进入判别器的计算图,以确保在这里只计算判别器对生成数据的预测概率"
            dgz1 = disc(gz.detach())
            "所有生成数据的损失均值,在训练生成对抗网络(GAN)的判别器时,对于生成的数据,我们希望判别器的输出接近零,表示生成数据被正确分类为假数据。因此,我们将目标设置为与生成数据对应的标签,通常是零"
            loss_fake = criterion(dgz1, torch.zeros_like(dgz1))
            loss_fake.backward()
            "判别器对生成数据的预测概率 dx 的平均值,然后使用 .item() 方法将其转换为标量值,并将结果存储在 D_G_Z1 变量中"
            D_G_z1 = dgz1.mean().item()
            "判别器对真实数据的损失和对生成数据的损失之和。这个总损失通常用于衡量判别器的性能"
            errorD = loss_real + loss_fake
            "errorD.backward() #直接对errorD反向传播,也可分别对loss_real,loss_fake执行反向传播"
            "更新判别器上的权重"
            op_disc.step()
            "清零判别器迭代后的梯度"
            disc.zero_grad()

            # 生成器反向传播*==========================================================================
            "注意,由于在此时判别器上的权重已经被更新过了,所以dgz的值会变化,需要重新生成"
            "得到判别器对生成数据的输出 dgz2"
            dgz2 = disc(gz)
            "计算了生成器的损失。与判别器的损失不同,这里我们希望生成器生成的数据被判别器识别为真实数据,所以我们使用目标值为1的损失函数来计算生成器的损失"
            Gloss = criterion(dgz2, torch.ones_like(dgz2))
            "反向传播"
            Gloss.backward()
            "更新生成器上的权重"
            op_gen.step()
            "清零生成器更新后梯度"
            gen.zero_grad()
            D_G_z2 = dgz2.mean().item()
            # print(f"第{ epoch+1 }次训练")

你可能感兴趣的:(经典原文模型,生成对抗网络,人工智能,神经网络)