AIGC:使用生成对抗网络GAN实现MINST手写数字图像生成

1 生成对抗网络

生成对抗网络(Generative Adversarial Networks, GAN)是一种非常经典的生成式模型,它受到双人零和博弈的启发,让两个神经网络在相互博弈中进行学习,开创了生成式模型的新范式。从  2017 年以后,GAN相关的论文呈现井喷式增长。GAN 的应用十分广泛,它的应用包括图像合成、图像编辑、风格迁移、图像超分辨率以及图像转换,数据增强等。

1.1 背景

具有开创性工作的生成对抗网络原文由Goodfellow在2014年发表,当时深度学习领域最好的成果有很大一部分都是判别式模型(比如AlexNet),它们使用反向传播和dropout方法,让模型能够拥有一个良好的梯度结构,从而更迅速地收敛到一个较好的状态。而此时的生成式模型相比之下效果却并不是很好。

论文地址:https://arxiv.org/pdf/1406.2661.pdf

生成式模型的任务是:给你一组原始数据 ,请你生成一组新数据 ,使得这两组数据看起来“尽可能的相似”。而这个任务本质上就是让新数据的概率密度 尽可能地接近原始数据的概率密度 ,也就是说我们要学习 这个函数。传统的生成式模型会将 建模出来,并通过梯度下降优化模型中的参数,最终达到逼近的目的。从理论上来这个方法似乎挺好,但是真实情况下的概率分布往往是难以逼近的,在当时也并没有今天这么多的优化技术,所以实际上这个方法效果一般。

因此,本文作者提供了一种新的思路:我们不去直接逼近 ,而是添加一个判别器来判断生成出来的数据“像不像”原始数据,并且轮流训练这个判别器和生成器。假如这两个model最终都能收敛,那么收敛到的位置一定是生成效果“最像”,而判别效果“最准”的位置,我们直接取此时的生成器作为最终答案即可。原文中举了一个例子:生成器就好比一个造假币的团伙,判别器就好比警察,现在我们来看他们互相博弈的过程:

  1. 第一次,假币团伙没有经验,造了很多7块钱一张的假币,拿出去一花就被警察抓了。
  2. 第二次,假币团伙学聪明了,造了很多看起来很正常的假币;警察一开始没有发现,但是在拿到了一张假币样本之后开始研究,发现了假币的一个缺陷(比如没有让盲人摸的手感线),于是向民众推广这种辨别假币的方法,最终假币团伙的假币就花不出去了。
  3. 第三次,假币团伙修复了上面的缺陷,造了很多新假币;警察又开始寻找新的缺陷并推广……
  4. 久而久之,假币团伙造的假币会越来越接近真币,而警察辨别假币的水平也会越来越高。就像生物学上的协同进化一样,这两个团体会在互相博弈中共同进步。

1.2 工作原理

GAN网络能够在不使用标注数据的情况下来进行生成任务的学习。GAN网络由一个生成器和一个判别器组成。生成器从潜在空间随机取样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别器的输入则为真实样本或生成器的输出,其目的是将生成器的输出从真实样本中尽可能分别出来。生成器和判别器相互对抗、不断学习,最终目的使得判别器无法判断生成器的输出结果是否真实。

1.2.1 GAN网络组成

AIGC:使用生成对抗网络GAN实现MINST手写数字图像生成_第1张图片

GAN网络由生成器和判别器组成:

  • 生成器。生成器学习如何生成看似合理的数据。对判别器来说,这些生成的实例会变成负面训练样本。
  • 判别器。判别器学习如何通过真实数据的学习来辨别出生成器生成的假数据。判别器将惩罚生成器生成的“不可理(假)”的数据结果。

1.2.2 判别器如何工作

判别器在生成对抗网络中,简单来说是一个分类器。该分类器尝试从生成器生成的假数据中识别真实数据。它可以使用任何适用于数据分类的网络架构。判别器在训练中使用误差反向传播机制来计算损失和更新权重参数。

AIGC:使用生成对抗网络GAN实现MINST手写数字图像生成_第2张图片

判别器的训练数据有两处来源,

  • 真实数据。真实数据实例,比如人的照片。在训练中,判别器把这些实例用作正面样本。
  • 虚假数据。生成器生成的实例。在训练中,判别器把这些实例用作负面样本。

上图中两个Sample的框就是这两种输入到判别器的样本。注意,在判别器训练时,生成器不会训练,即在生成器为判别器生成示例数据时,生成器的权重保持恒定。

在训练判别器时,判别器连接到两个损失函数。在训练时,判别器忽略生成器的损失而只使用判别器损失。在训练过程中,

  • 判别器对真实数据和来自生成器生成的假数据进行分类。
  • 判别器的损失函数将惩罚由判别器产生的误判,比如把真实实例判定成假,或者把假的实例判定为真。
  • 判别器通过对来自于判别器网损失函数计算的损失进行反向传播。如上图。

1.2.3 生成器如何工作

生成对抗网络里的生成器,通过接受来自于判别器的反馈来学习如何创建假数据。生成器学习如何让(欺骗)判别器把它的输出归类为真实数据。

AIGC:使用生成对抗网络GAN实现MINST手写数字图像生成_第3张图片

相对于判别器的训练,生成器的训练要求生成器与判别器有更加紧密的集成。生成器训练包含:

  • 随机输入。神经网络需要某种形式的输入。通常,为了达到某种目的而输入数据,比如一个输入的实例用来进行分类任务或者预测。但当希望输出一整个全新的数据实例,用什么样输入数据呢?

  • 最常见基础形式里,GAN使用随机噪音作为它的输入。然后,生成器将把随机噪音转换成有意义的输出。通过引入噪音,可以从不同分布形式的不同空间采样,让GAN生成一个宽域的数据

  • 实验结果表明,不同噪音的分布不会产生太大影响。因此,可以选择相对较易的采样来源,比如,均匀分布。方便起见,噪音采样空间的维度一般小于输出空间的维度。注意,有些GAN变种不使用随机输入来形成输出。

  • 生成器网络,负责把随机输入转换成数据实例。

  • 判别器网络,负责把上一步生成的数据归类,判别器输出。

  • 生成器的损失函数,负责惩罚企图蒙骗判别器失败的情况(即生成器生成的假数据,被识别器成功识破)。

1.3 训练

 1.3.1 GAN训练的整个步骤

  • 当训练开始,生成器生成一个些很明显的假数据,判别器能快速地学习如何识别出是不是假数据,

AIGC:使用生成对抗网络GAN实现MINST手写数字图像生成_第4张图片

  • 随着训练稳步推进,生成器更接近于能生成蒙骗判别器的输出数据。

  • 最终,如果生成器训练得当,在识别真实和虚假方面,判别器变得差强人意,而且将开始把假数据分类为真实数据,识别的准确率降低。

1.3.2 使用判别器训练生成器

要训练神经网络,通过修改网络的权重来减少误差或者输出的损失。在GAN里却不同,生成器不直接连接到损失函数来试图影响损失,而是把生成的数据输出到判别器,而判别器会制造影响误差损失的输出。当生成器生成的数据被判别器成功识别成仿冒时,生成器损失函数会惩罚生成器。

另外,反向传播里也包含网络的额外处理。反向传播通过计算对输出的影响——更改后的权重在多大程度上影响输,来调整每个权重以使其在正确的方向上。但,生成器权重的影响取决于直接输出到判别器的权重的影响。因此,反向传播始于输出且穿过判别器回流到生成器。

在生成器训练时,不希望判别器更改,就像尝试击中一个移动目标,会让一个本身就麻烦的问题变得更加困难。所以,在训练生成器时使用如下流程,

  • 随机噪音采样作为输入。
  • 生成器从采样的随机噪音采样里生成输出。
  • 让判断器判断上述输出是“真”或“假”,以此作为生成器的输出。
  • 从判别器的分类输出计算误差损失。
  • 穿过判别器和生成器的反向传播,从而获得梯度。
  • 使用梯度来更新生成器的权重。

这个流程是生成器训练的一个迭代。

1.3.3 交替训练

生成器和判别器有不同的训练流程,那么如何才能作为一个整体来训练GAN呢?GAN的训练有交替阶段,

  • 判别器训练一个或者多个迭代。
  • 生成器训练一个或者多个迭代。
  • 不断重复1和2步来训练生成器和判别器。

在判别器训练的阶段,保持生成器不变。因为判别器训练会尝试从仿冒数据里分辨出真实数据,判别器必须学习如何识别生成器的缺陷。这就是经过完整训练的生成器和只能生成随机输出的未训练生成器的不同之处。

类似地,在生成器训练的阶段,保持判别器不变。否则,生成器像尝试击中移动目标一样,可能永远无法收敛。

这种往复训练使得GAN能够处理另外一些棘手的数据生成问题。开始于相对较简单的分类任务问题,从而获得一个解决生成难题的立足点。相反地,如果不能训练一个分类器来识别真实数据和生成数据的区别,甚至无法识别与随机初始化输出的区别,那么GAN的训练根本无法开始。

1.3.4 收敛

随着训练进行,生成器不断改善,相对地,由于不能再轻易地识别出真实数据和假冒数据的,判别器的表现越来越差。当生成器完美地成功生成时,判别器的正确率只有50%。基本上和抛一枚硬币来预测正反一样的概率一样。

这种进度展现了作为整体GAN的一个问题,判别器的反馈随着时间的推移越来越不具有意义。过了这个节点之后,即判别器完全给出了随机反馈,如果继续训练GAN,那么生成器将使用判别器给出的无效反馈进行训练,那么生成器的质量可能会崩塌。

对于GAN来说,收敛往往是一个闪现的点,而不是牢固的、稳态的。

2 基于pytorch实现MINIST手写数字图像识别

# -*- coding: utf-8 -*-
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch import nn, optim
from torch.nn import functional as F
from tqdm import tqdm
import os

os.chdir(os.path.dirname(__file__))


class Generator(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.linear = nn.Linear(latent_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.linear(x))

        x = torch.sigmoid(self.out(x))
        return x


class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Discriminator, self).__init__()
        self.linear = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size,1)

    def forward(self, x):
        x = F.relu(self.linear(x))
        x = torch.sigmoid(self.out(x))
        return x


loss_BCE = torch.nn.BCELoss(reduction='sum')


# 压缩后的特征维度
latent_size = 16

# encoder和decoder中间层的维度
hidden_size = 128

# 原始图片和生成图片的维度
input_size = output_size = 28*28


epochs = 100
batch_size = 32
learning_rate = 1e-5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

modelname = ['gan-G.pth', 'gan-D.pth']
model_g = Generator(latent_size, hidden_size, output_size).to(device)
model_d = Discriminator(input_size, hidden_size).to(device)

optim_g = torch.optim.Adam(model_g.parameters(), lr=learning_rate)
optim_d = torch.optim.Adam(model_d.parameters(), lr=learning_rate)

try:
    model_g.load_state_dict(torch.load(modelname[0]))
    model_d.load_state_dict(torch.load(modelname[1]))
    print('[INFO] Load Model complete')
except:
    pass


train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False)


for epoch in range(epochs):
    Gen_loss = 0
    Dis_loss = 0
    for imgs, lbls in tqdm(train_loader, desc=f'[train]epoch:{epoch}'):
        bs = imgs.shape[0]
        T_imgs = imgs.view(bs, input_size).to(device)
        T_lbl = torch.ones(bs, 1).to(device)
        F_lbl = torch.zeros(bs, 1).to(device)

        sample = torch.randn(bs, latent_size).to(device)
        F_imgs = model_g(sample)
        F_Dis = model_d(F_imgs)

        loss_g = loss_BCE(F_Dis, T_lbl)
        loss_g.backward()
        optim_g.step()
        optim_g.zero_grad()

        # 训练判别器, 使用判别器分别判断真实图像和伪造图像
        T_Dis = model_d(T_imgs)
        F_Dis = model_d(F_imgs.detach())

        loss_d_T = loss_BCE(T_Dis, T_lbl)
        loss_d_F = loss_BCE(F_Dis, F_lbl)
        loss_d = loss_d_T + loss_d_F
        loss_d.backward()
        optim_d.step()
        optim_d.zero_grad()

        Gen_loss += loss_g.item()
        Dis_loss += loss_d.item()
    print(f'epoch:{epoch}|Train G Loss:', Gen_loss/len(train_loader.dataset),
          ' Train D Loss:', Dis_loss/len(train_loader.dataset))

    model_g.eval()
    model_d.eval()
    Gen_score = 0
    Dis_score = 0
    for imgs, lbls in tqdm(test_loader, desc=f'[eval]epoch:{epoch}'):
        bs = imgs.shape[0]
        T_imgs = imgs.view(bs, input_size).to(device)
        sample = torch.randn(bs, latent_size).to(device)

        F_imgs = model_g(sample)

        F_Dis = model_d(F_imgs)
        T_Dis = model_d(T_imgs)

        Gen_score += int(sum(F_Dis >= 0.5))
        Dis_score += int(sum(T_Dis >= 0.5)) + int(sum(F_Dis < 0.5))

    print(f'epoch:{epoch}|Test G Score:', Gen_score/len(test_loader.dataset),
          ' Test D Score:', Dis_score/len(test_loader.dataset)/2)

    model_g.train()
    model_d.train()

    model_g.eval()
    noise = torch.randn(1, latent_size).to(device)
    gen_imgs = model_g(noise)
    gen_imgs = gen_imgs[0].view(28, 28)
    plt.matshow(gen_imgs.cpu().detach().numpy())
    plt.show()
    model_g.train()

    torch.save(model_g.state_dict(), modelname[0])
    torch.save(model_d.state_dict(), modelname[1])

sample = torch.randn(1, latent_size).to(device)
gen_imgs = model_g(sample)
gen_imgs = gen_imgs[0].view(28, 28)
plt.matshow(gen_imgs.cpu().detach().numpy())
plt.show()

dataset = datasets.MNIST('./data', train=False, transform=transforms.ToTensor())
index = 0
raw = dataset[index][0].view(28, 28)
plt.matshow(raw.cpu().detach().numpy())
plt.show()
raw = raw.view(1, 28*28)
result = model_d(raw.to(device))
print('该图为真概率为:', result.cpu().detach().numpy())

运行结果展示:

epoch:0|Train G Loss: 0.6875885964711507  Train D Loss: 0.900547916316986
[eval]epoch:0: 100%|██████████| 313/313 [00:01<00:00, 156.80it/s]
epoch:0|Test G Score: 0.508  Test D Score: 0.746
[train]epoch:1: 100%|██████████| 1875/1875 [00:16<00:00, 115.81it/s]
[eval]epoch:1:   0%|          | 0/313 [00:00

第一个图是刚开始训练时的生成数据,第二个图是训练到50个epochs生成的数据,第三个图是训练到100epochs生成的数据。

AIGC:使用生成对抗网络GAN实现MINST手写数字图像生成_第5张图片AIGC:使用生成对抗网络GAN实现MINST手写数字图像生成_第6张图片AIGC:使用生成对抗网络GAN实现MINST手写数字图像生成_第7张图片

3 总结

生成模型是深度学习领域难度较大且较为重要的一类模型。生成对抗网络能够在半监督或者无监督的应用场景下进行生成任务的学习。目前而言,生成对抗网络在计算机视觉、自然语言处理等领域取得了令人惊叹的成果。生成对抗模型是近年来复杂数据分布上无监督学习最具前景的方法之一。

你可能感兴趣的:(AIGC,AIGC,生成对抗网络,人工智能)