[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121878299


目录

第1章 生成对抗网络GAN概述与主要应用

第2章 生成对抗网络GAN网络的结构原理

2.1 GAN网络的目标达成器:生成网络G

2.2 GNN网络的助教:判决网络

第3章 GAN网络的模型训练:G网络和D网络如何协同学习,共同进步

3.1 来自大自然的启示

3.2 “对抗”的来源

3.3 GNN网络的模型

3.5 G网络的训练

3.6 D网络的训练

3.7 人脸迭代示例

3.8 代码示例

3.9 数据集

附录:GAN的原理视频讲解推荐


第1章 生成对抗网络GAN概述与主要应用

https://blog.csdn.net/HiWangWenBing/article/details/121881726https://blog.csdn.net/HiWangWenBing/article/details/121881726

第2章 生成对抗网络GAN网络的结构原理

2.1 GAN网络的目标达成器:生成网络G

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第1张图片

给定一个输入向量,经过中间的“生成”升级网络,自动生成(输出一个)有一个有意义的图片或文本,而不同于只对输入向量进行分类的CNN网络。这就是网络训练后的能力或效果!!!

这就是生成对抗网络GAN中的第一个关键词“生成“

(1)输入是什么?

输入是任意长度的向量,长度越大,表达的原始信息就越多。输入向量可以是:

  • 可以是一个随机产生的向量(这种情况只适合学习,没有实际用途)
  • 有意义的图片:根据有意义的输入图片,经过神经网络变换后,生成带有新特征的图片或文本(增加新特征、修改原有特征、减少某些特征),能够实现图片到图片,或图片到文本的转换。
  • 有意义的文本:根据有意义的输入文本,经过神经网络变换后,生成带有新特征的新文本或图片(增加新特征、修改原有特征、减少某些特征),能够实现文本到文本,或文本到图片的转换。

(2)输出是什么?

  • 有意义的图片:根据有意义的输入图片,经过神经网络变换后,生成带有新特征的图片或文本(增加新特征、修改原有特征、减少某些特征),能够实现图片到图片,或图片到文本的转换。
  • 有意义的文本:根据有意义的输入文本,经过神经网络变换后,生成带有新特征的新文本或图片(增加新特征、修改原有特征、减少某些特征),能够实现文本到文本,或文本到图片的转换。

备注:输出不是简单的分类号!!!

通过改变输入向量,就可以生成不同特征的输出图片或文本:

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第2张图片

(3)中间的G网络是什么?

中间G网络可以任意的,GNN网络本身并没有对它进行限定。

  • 全连接网络FC
  • 卷积网络CNN
  • 时序网络RNN

中间的G神经网络是GNN的两大核心之一。GNN网络的“生成”功能,正是通过G网络来完成的。

(4)G网络的标签是什么?是通过标签训练得到的吗?

G网络是不是与FC, CNN或RNN一样,需要给输出图片指定标签呢?

实际上,“生成网络G”输出的不是一个简单的分类号,而是一个有意义的图片或一段文本。

因此,无法给G网络指定一个简单的标签,G网络自身是没有标签的。

这就是GNN网络称为“无监督式学习”网络的原因。

(5)G网络的能力形成来源是什么?

G网络没有标签,那么G网络能够自学成才吗?!!!

答案是否定的,因为G网络的输出,不是简单的分类:

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第3张图片

 G网络的输出,是有意义的复杂的图片或文本,自己是无法通过算法自学的!!!

那么G网络的能力是哪里来的?是如何获得才能力的呢?

由于输出过于复杂,即使给他标签好的图片,它也无法通过自身的自学完成!!!

需要给G网络指定能够对G网络的输出进行判别的新网络: 判别网络。

就像给孩子指定一个老师一样。

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第4张图片

如果说,FC, CNN, RNN网络的指导老师是一个标签。

那么GNN网络中的G网络的指导老师,也是一个网络, 称为判决网络D,而不是简单的标签。 

G网络负责生成图片。

D网络负责对生成的图片,根据已有的样本图片进行判决,是否与真实的样本的特征一致。

备注:这里提到的D网络,并不是预先训练好的模型,也是需要与G网络一起进行学习,因此,更准确的讲,D网络,更新是一个美术的助教,它只负责根据标准答案批改作业,自己并非先知先觉,而是要与G网络一起,同步学习进步。

那么我们就来看看D网络的详情。

2.2 GNN网络的助教:判决网络

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第5张图片

(1)判别网络D的输入是什么?

  • 真实的图片或文本
  • G网络生成的图片或文本

备注:

真实的图片和G网络的生成图片,并不是同时送入到D网络中的。

而是各自独立的送入到D网络中进行判别的。

(2)判别网络D的输出是什么?

判别的输出,并不是图片的分类,如1000分类。

判别的输出是真实图片的概率[0, 1], 1: 表示,判别该图片为100%真实图片,0:表示判别输入图片100%不是真实图片,而是生成图片或其他图片,即是真实图片的概率为0%。

如果硬要把D网络归为分类网络这个大类的话,那么D网络实际上是:判断是不是真实图片的2分类网络。

对G和D网络进行训练的目标:

  • 对真实图片的判决接近于1
  • 对生成图片的判决接近于1

也就是说,训练的目的是使得生成图片具备与真实图片完全相同的内在特征。

当然,生成图片,除了具备与真实图片完全相同的内在特征外,还具备自身原有的特征。

这就是生成图片的重要作用于意义!!!

否则的话,就是简单的复制,简单的复制是没有意义!!!

GNN网络的意义:

生成的图片与真实图片具备完全相同的内在特征,同时具备自己的新的特征,这就是创作!!!

(3)判别网络D的类型是什么?

D网络可以任意的,GNN网络本身并没有对它进行限定。

  • 全连接网络FC
  • 卷积网络CNN
  • 时序网络RNN

(4)判别网络D的标签是什么?

D网络的标签是比较特别的,它并不是真实图片或生成图片的分类,比如1000分类等。

D网络的标签只有1或0。

  • 1:表示输入图片是真实图片,凡是真实图片,所有的图片的标签都是1.
  • 0:表示输入图片是生成图片,凡是真实图片,所有的图片的标签都是0.

由于这个标签的特殊性,实际上,真实图片是没有认为的标签的,输出图片也没有认为的标签。

这个标签完全是在模型训练的时候,由程序员自己设定的,并不是数据集里面提供的标签!!!

数据集里是没有任何标签的,就仅仅是原始的输入图片或文本而已。

因此,GNN网络,被归为“无监督”机器学习。

(5)判别网络D在什么时候其作用?

D网络在训练时候起作用。 在模型训练好后,起助教作用的D网络的任务就完成。

后续的图片生成,就全靠G网络了。D网络是爱莫能助,顶多用D再判断一下,生成图片是否符合真实图片的要求,如果不符合,则给G网络提个意见:“继续培训吧“?如果D网络说,当前的水平已经满足要求了,我不想再学了,D网络其实也是没有办法的,帮不了什么忙。

接下来,看看D网络是如何帮助G网络完成能力训练的!

第3章 GAN网络的模型训练:G网络和D网络如何协同学习,共同进步

3.1 来自大自然的启示

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第6张图片

 捕食者与伪装者都在各自独立迭代、进化。

生成者或伪装者:通过自我进化(W, B参数的更新)使它的外形图形输出,能够尽可能骗过判决者,捕食者,让捕食者尽可能的判断其自身不是伪装者,而是大自然真实树枝或树叶。

判决者或捕食者:通过自我进化(W, B参数的更新)尽可能提高自己识别伪装者的图像与真实树枝、树叶的差别,尽可能的把伪装者的图片,判决成伪装者,而不是真实的树枝或树叶。

经过一代代的迭代进化,生成者或伪装者的输出图像越来越接近真实的树枝或树叶。

经过一代代的迭代进化,判决者或捕食者越来越能够区分在细微的程度区分伪装者与真实的树枝或树叶的差别。

如果进化到某一天,无论判决者或捕食者再怎么进化,也无法区分生成者或伪装者的输出图像。

这就完成真实的以假乱真的程度了!!!!

3.2 “对抗”的来源

GNN网络的训练:就是要达到上述类似的目标。

这正是生成对抗网络中的“对抗的由来。

当然,生成对抗网络最初的作者,举的案例是:钱币的造假者与警察(银行)之间的对抗。

而大自然捕食者与被捕食者之间的对抗更能体现GNN网络的对抗与协助的对立与统一的关系,而不仅仅是“对抗“的关系。

3.3 GNN网络的模型

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第7张图片

问题:

为什么生成网络G不能自己学习? 

为什么判决网络不能自己生成图片,还需要通过生成网络生成图片?

为什么学生需要老师?

为什么老师不自己编程,而需要教会学生编程呢?

3.5 G网络的训练

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第8张图片

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第9张图片

(1)1.1 生成随机输入向量

(2)1.1 使用当前的G网络生成图片

(3)1.2 使用当前的D网络对图片进行预测,得到输出output_fake

(4)1.3 求output_fake与1之间的loss_fake = BCELoss(output_fake, 1), 其中1表示:真实图片。

(5)1.3 对G网络进行反向求导、迭代,使得loss最小,通过迭代G网络的W, B参数,使得G网络的输出图片被伪装成真实图片。也集是G网络负责伪装。

loss算法和迭代算法如下:

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第10张图片

3.6 D网络的训练

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第11张图片

 [人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第12张图片

(1)2.1 对伪装后的图片进行预测,得到output_fake。

(2)2.1 求output_fake与0之间的loss_fake = BCELoss(output_fake, 0), 其中0表示:生成图片。此时,D网络需要有能力把生成图片给鉴别出来,即loss_fake要尽可能的小。

(3)2.2 对真实的图片进行预测,得到output_real。

(4)2.2 求output_real与1之间的loss_real = BCELoss(output_real, 1), 其中1表示:真实图片。此时,D网络需要有能力把真实图片给鉴别出来,即loss_real 要尽可能的小。

(5)2.3 计算D网络总的loss=loss_fake + loss_real,然后对D网络反向求导、梯度迭代,目的是提升D网络的鉴别能力,把真实图片鉴别成1, 把生成图片鉴别成0.

loss算法和迭代算法如下:

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第13张图片

通过对G网和D网络的各种更新,就完成了一个batch的一次迭代。然后启动下个batch的迭代。 

3.7 人脸迭代示例

(1)迭代100次后的人脸输出:轮廓特征

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第14张图片

(2)迭代1000次后的人脸输出:有眼睛、脸部特征

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第15张图片

(3)迭代2000次后的人脸输出:嘴巴特征

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第16张图片

(4) 迭代5000次后的人脸输出:水汪汪大眼睛特征

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第17张图片

 (5)迭代10000次后的人脸输出:清晰度提升

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第18张图片

 (6)迭代50000次后的人脸输出:最终输出

[人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)_第19张图片

3.8 代码示例

for epoch in range(n_epochs):
    # 读取一个batch的数据
    for i, (imgs, _) in enumerate(dataloader):
        # Adversarial ground truths
        # 获取当前dataloader imgs的batch size
        tmp_batch_size = imgs.size(0) 
        
        # 构建训练所需要的标签
        # 真实图片的标签全部为1
        valid = Variable(Tensor(tmp_batch_size, 1).fill_(1.0), requires_grad=False)
        # 生成图片的标签全部为0
        fake  = Variable(Tensor(tmp_batch_size, 1).fill_(0.0), requires_grad=False)

        # -----------------
        #  Train Generator 生成网络
        # -----------------
        # 复位G网络的梯度值
        optimizer_G.zero_grad()

        # Sample noise as generator input
        # 生成与真实图片相同batch的输入向量
        # 这里的输入是latent_dim=100长度的一维向量,向量的值为随机值
        # 随机值在每次迭代时会发生变化吗?
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
        
        # Generate a batch of images
        # 用相同的生成网络,针对不同的批样本输入,生成一批不同的图片
        gen_imgs = generator(z)
        
        # 对生成的图片进行预测
        y_gen = discriminator(gen_imgs)
        
        # Loss measures generator's ability to fool the discriminator
        # 目标是:使得当前G网络生成的图片,骗过当前的判决网络D, 判决为真实图片
        # 每次迭代:G网络进化一点点,即使得g_loss降低一点点。
        # 由于每次迭代, D网络也在进化,导致g_loss再提升一点点
        # g_loss反应的是:使用当前的D网络,判断G网络生成的图片,是不是真实图片中的一个
        # g_loss越小,生成的图片越接近真实的图片中的一个
        g_loss = adversarial_loss(y_gen, valid)

        # 求G网络的梯度
        g_loss.backward()
        
        # 反向迭代生成网络G, 只迭代G网络的W,B参数
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator 判决网络
        # ---------------------
        # 判决网络通过提高鉴别能力,朝着尽可能能够区分真实图片和生成图片的方向迭代、进化
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        # 用真实图片进行预测
        # Configure input
        real_imgs = Variable(imgs.type(Tensor))
        
        # 判决对生成图片进行判决: 尽可能识别生成图片与真实图片的差别
        # 优化判决网络,尽可能使得判决网络把真实图片判定为1, 因此使用了标签1
        y_real = discriminator(real_imgs)
        real_loss = adversarial_loss(y_real, valid)
        
        # 优化判决网络,尽可能使得判决网络把生成图片判定为0,因此使用了标签0
        y_fake = discriminator(gen_imgs.detach())
        fake_loss = adversarial_loss(y_fake, fake)
        
        # 对loss进行叠加:真实图片的判决与1的距离,生成图片的判决与0的距离
        d_loss = (real_loss + fake_loss) / 2

        # 判决网络D的反向求导
        d_loss.backward()
        
        # 反向迭代判决网络D
        optimizer_D.step()
        
        
        # 每隔sample_interval=400个生成样本,存储一个到文件中
        batches_done = epoch * len(dataloader) + i
        
        if batches_done % sample_interval == 0:
            #save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
            #存储一个batch中的所有生成图片,每行8张图片
            save_image(gen_imgs.data[:], "images/%d.png" % batches_done, nrow=8, normalize=True)
            
            print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))
            print(z)

print("DONE")

3.9 数据集

(1)判决网络D的真实图片的数据集

采用不同的数据集,就会产生不同的输出。这个数据集是关键。

(2)判决网络D的生成数据集

来自于生成网络的临时输入

(2)生成网络G的数据集

  • 本文主要采用随机值作为生成网络G的数据集,生成网络最终图片的输出完全取决于判决网络D的真实图片数据集。
  • 如果采用其他数据集作为生成网络G的输入,则生成网络最终图片的输出,不仅仅取决于判决网络D的真实图片数据集,还具备生成网络输入数据集的特征!!!!

附录:GAN的原理视频讲解推荐

​​​​​​不愧是清华大佬,把GAN生成对抗网络讲得如此清新脱俗,简单明了!理论讲解及项目实战(建议收藏)(深度学习)_哔哩哔哩_bilibilibaidu需要视频学习资料的伙伴,+小姐姐威信:THY89521 免费领取! 还可领取一份200G人工智能学习资料礼包(内含:两大Pytorch、TensorFlow实战框架视频、图像识别、OpenCV、计算机视觉、深度学习与神经网络等等等等视频、代码、PPT以及深度学习书籍 !你想要的里面都有!)https://www.bilibili.com/video/BV1h44y1a7w3


 作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121878299

你可能感兴趣的:(人工智能-深度学习,生成对抗网络,深度学习,神经网络)