GAN对抗神经网络

前言:因为项目要求要学习一下GAN来做图像修复,因此作者写了这样的一篇博客来深度学习一下GAN对抗神将网络。


一、GAN简介

GAN(Generative Adversarial Network)对抗神经网络是一种深度学习模型,由一对神经网络组成:生成器和判别器。

生成器的目的是学习生成数据的分布,它将随机噪声作为输入并生成与真实数据相似的输出。判别器的目的是学习将生成器生成的数据与真实数据区分开来。生成器和判别器在训练过程中相互对抗,生成器尝试生成越来越真实的数据来欺骗判别器,而判别器尝试识别哪些数据是真实的哪些是生成的。

在训练过程中,生成器和判别器相互对抗,通过反复调整参数,生成器可以生成高质量的数据,并欺骗判别器难以判断哪些数据是真实的。GAN已被广泛应用于图像生成、语音合成、自然语言处理等领域,并产生了许多有趣的应用,如DeepFake等

二、GAN网络主要成分

GAN网络的主要成分包括:

  1. 生成器(Generator):生成器是一个神经网络模型,它接收随机噪声作为输入,并输出一些与原始数据相似的数据。生成器的目标是尽量准确地模拟原始数据的分布。

  2. 判别器(Discriminator):判别器也是一个神经网络模型,它的输入是一组数据,输出一个概率值,用于判断该数据是真实数据还是生成器生成的数据。判别器的目标是尽可能准确地区分真实数据和生成器生成的数据。

  3. 损失函数(Loss Function):GAN网络的损失函数由两部分组成。一部分是生成器的损失函数,它衡量生成器生成的数据与真实数据之间的差异程度;另一部分是判别器的损失函数,它衡量判别器预测的概率是否准确。

  4. 优化器(Optimizer):GAN网络使用反向传播算法训练模型,优化器的作用是根据损失函数的结果来更新模型的参数,使得生成器和判别器不断优化,让生成器生成的数据更加逼真,让判别器可以更好地区分真实数据和生成器生成的数据。

  5. 数据集(Dataset):GAN网络需要大量的数据来进行训练,训练数据集的质量和数量对于GAN网络的性能影响非常大。通常情况下,数据集应该包含真实数据和标签,并与生成器的输出数据进行比较。

三、简化网路结构    

GAN对抗神经网络_第1张图片

GAN对抗神经网络_第2张图片         作者这里在学习的时候手绘了一下网络结构,这里作者概述一下整体的结构吧。首先我们要明确用这个网络的前提下有两类数据,一类是real data,一类是fake data,然后判别器D要尽可能的认出real data,生成器要尽可能让fake data被D认定为real data。其实总的来说就是二者博弈的关系的关系。首先就是生成器(G),它的作用其实是接受噪音Z用来生成与真实数据相似的数据,注意:生成器生成的数据的size与real data 中的数据大小是一致的。然后就是判别器(D)。最简单GAN的话这里做的就是个二分类任务,就是判别真伪。

        为了更加形象的解释gan作者这里举个例子。假设你要去买名表,但是呢你从来没买过名表,你很难判断表的真伪,而买名表的经验可以防止你被奸商欺骗。当你开始讲大多数名表标记为假表(被骗之后),卖家就开始生产高仿表。然后你再去买表。二者相互博弈,你的经验在增加,卖家的造假经验也在提高。最后生成器生成的东西就达到与真实的东西尽可能的接近了。

四、各个组成部分

4.1生成器

        GAN的生成器主要是用来生成与真实数据相似的数据。具体来说,生成器接收一个随机噪声向量作为输入,并通过神经网络生成一些数据,这些数据与真实数据有相同的特征、分布和模式。生成器的目标是在训练过程中逐渐学习到真实数据的分布和模式,从而生成出与真实数据相似的数据。生成器的生成结果将被送往判别器进行判断,判断其是否具有真实数据的特征和分布,如果判断为真实数据,则生成器达到了预期目标。在训练过程中,生成器的目标是优化生成的数据,使其尽可能逼近真实数据的分布,从而让判别器难以区分真实数据和生成的数据。

        作者这里简单写了个D的网络

class Generator(nn.Module):
    def __init__(self):#初始化
        super(Generator, self).__init__()#这里做的是一个重写操作
        self.fc1 = nn.Linear(100, 256)#全连接层#输入100维,输出256维
        self.fc2 = nn.Linear(256, 512)#全连接层#输入256维,输出512维
        self.fc3 = nn.Linear(512, 1024)#全连接层#输入512维,输出1024维
        self.fc4 = nn.Linear(1024, 28*28)#全连接层#输入1024维,输出28*28维
        self.relu = nn.ReLU()#激活函数,这里使用的是relu函数,用于增加网络的非线性
        self.tanh = nn.Tanh()#激活函数,这里使用的是tanh函数,用于增加网络的非线性
    def forward(self, x):#前向传播
        x = self.relu(self.fc1(x))#输入x,经过全连接层,再经过relu激活函数
        x = self.relu(self.fc2(x))#输入x,经过全连接层,再经过relu激活函数
        x = self.relu(self.fc3(x))#输入x,经过全连接层,再经过relu激活函数
        return self.tanh(self.fc4(x))#输入x,经过全连接层,再经过tanh激活函数
    

具体的里面的数值根据你们图像的参数来。 

4.2判别器

        GAN的判别器是用来判断生成器产生的数据是真实数据还是假数据。具体来说,判别器接收一个数据作为输入,并通过神经网络判断其是否为真实数据。判别器的目标是在训练过程中逐渐学习到真实数据的分布和模式,从而能够区分真实数据和生成器生成的数据。在训练过程中,判别器的目标是最大化正确分类真实数据和错误分类生成的数据的概率,从而让生成器生成的数据更接近真实数据的分布。GAN的生成器和判别器是相互博弈的,生成器不断优化生成的数据,使其更接近真实数据的分布,判别器则不断学习真实数据的分布和模式,从而能够更准确地区分真实数据和生成的数据。依靠双方博弈,不断优化生成器和判别器的训练目标,最终可以使生成器生成的数据逼近真实数据的分布。  

class Discriminator(nn.Module):
    def __init__(self):#初始化
        super(Discriminator, self).__init__()#这里做的是一个重写操作
        self.fc1 = nn.Linear(28*28, 1024)#全连接层#输入28*28维,输出1024维
        self.fc2 = nn.Linear(1024, 512)#全连接层#输入1024维,输出512维
        self.fc3 = nn.Linear(512, 256)#全连接层#输入512维,输出256维
        self.fc4 = nn.Linear(256, 1)#全连接层#输入256维,输出1维
        self.relu = nn.ReLU()#激活函数,这里使用的是relu函数,用于增加网络的非线性
        self.sigmoid = nn.Sigmoid()#激活函数,这里使用的是sigmoid函数,用于增加网络的非线性
    def forward(self, x):#前向传播
        x = self.relu(self.fc1(x))#输入x,经过全连接层,再经过relu激活函数
        x = self.relu(self.fc2(x))#输入x,经过全连接层,再经过relu激活函数
        x = self.relu(self.fc3(x))#输入x,经过全连接层,再经过relu激活函数
        return self.sigmoid(self.fc4(x))#输入x,经过全连接层,再经过sigmoid激活函数,输出结果为0-1之间的概率,sigomid函数是用来做二分类的,0.5以上为1,0.5以下为0

        作者也写了个关于判别器的类。 最后是输出成一个一维向量哦,我这个例子里的。

4.3测试与训练函数

        这里作者知识做简单的阐释,说一下测试与训练函数要干啥、

# 训练函数
def train(G, D, G_optimizer, D_optimizer, loss_func, train_loader, epoch):
    G_losses = []#记录生成器的loss
    D_losses = []#记录判别器的loss
    for step, (x, y) in enumerate(train_loader):#枚举数据
        b_x = x.view(-1, 28*28)#将数据转换为28*28维
        b_y = y#标签
        b_z = torch.randn((x.shape[0], 100))#生成随机噪声
        G_result = G(b_z)#生成器生成的结果
        D_real = D(b_x)#判别器判别的真实数据
        D_fake = D(G_result)#判别器判别的生成数据
        D_real_loss = loss_func(D_real, torch.ones_like(D_real))#判别器判别真实数据的loss
        D_fake_loss = loss_func(D_fake, torch.zeros_like(D_fake))#判别器判别生成数据的loss
        D_loss = D_real_loss + D_fake_loss#判别器的loss
        D_optimizer.zero_grad()#判别器梯度清零
        D_loss.backward()#反向传播
        D_optimizer.step()#判别器梯度下降
        G_result = G(b_z)#生成器生成的结果
        D_fake = D(G_result)#判别器判别的生成数据
        G_loss = loss_func(D_fake, torch.ones_like(D_fake))#生成器的loss
        G_optimizer.zero_grad()#生成器梯度清零
        G_loss.backward()#反向传播
        G_optimizer.step()#生成器梯度下降
        G_losses.append(G_loss.item())#记录生成器的loss
        D_losses.append(D_loss.item())#记录判别器的loss
        if step % 100 == 0:#每100次输出一次结果
            print('Epoch: ', epoch, '| Step: ', step, '| G loss: ', G_loss.item(), '| D loss: ', D_loss.item())
    #保存模型
    torch.save(G.state_dict(), './model/G.pth')
    torch.save(D.state_dict(), './model/D.pth')
    return G_losses, D_losses#返回生成器和判别器的loss

4.4主函数(补充一下损失函数)

        因为作者没有规定数据集的loader和loss function这里同意放到main函数中,loss函数基本是用BCELoss

if __name__ == '__main__':
    G = Generator()#生成器
    D = Discriminator()#判别器
    G_optimizer = Adam(G.parameters(), lr=0.0001)#生成器的优化器
    D_optimizer = Adam(D.parameters(), lr=0.0001)#判别器的优化器
    loss_func = nn.BCELoss()#loss函数
    train_loader = Data.DataLoader(dataset=torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=True), batch_size=64, shuffle=True)#训练数据
    for epoch in range(10):#训练10轮
        G_losses, D_losses = train(G, D, G_optimizer, D_optimizer, loss_func, train_loader, epoch)#训练
        test(G, epoch)#测试

这里补充一下

        BCELoss表示的是二元交叉熵损失(Binary Cross Entropy Loss),用于二分类问题。对于每个数据点,BCELoss会计算真实标签(0或1)和模型预测标签(0到1之间的概率值)之间的交叉熵损失。 具体来说,对于每个数据点,BCELoss的公式如下: 

        BCELOSS\left ( o,t \right ) =- \frac{1}{n}\sum(ti*logoi+(1-log)*(1-oi)))

五、总结 

        这里先简单介绍了一下gan的一些构造和构成,作者暂时还不能用这个模型写点东西出来,希望下次我们把数据集弄好后作者可以呈现一个图像修复的GAN网络,而这里所有写的代码是参考手写数据集的数据类型来处理的也就是MINST数据集。不过这篇文章作者想要说明的是GAN对抗神经网络的一些构成和简介,所以没有举比较好的例子,只是在研究里面的一些参数的意义,该做的事情。

        难得作者又开始为了提升自己写文章了,支持一下哦!!!

你可能感兴趣的:(神经网络,生成对抗网络,人工智能)