DCGAN生成对抗网络原理代码详细讲述(巨微)以及网络改进

声明

GAN是近年深度学习领域的利刃,革新的GAN总给人以惊叹的表现。本文的所有讲解均为本人通过相关的书籍、博客,然后加上自己两周的体会所总结的。本文所述原理由python实现,注解详细,若有不妥之处,多多指教。(从原理到代码(100%可运行,超详细),一个博客就够了)

  • DCGAN原理

DCGAN全称Deep Convolutional Generative Adversarial Networks,即深度卷积生成对抗网络。大家都知道GAN的核心思想就是博弈。好比现在有D,G二人进行拳击比赛,D想把G打败,但是G也想把D打败,他们相互竞争,相互加强,最后每人都非常强大。有了上面的例子,GAN的原理就好理解多了:GAN模型包括生成网络G鉴别网络D,生成网络的目的是生成假的图像使鉴别网络无法鉴别真假,鉴别网络的目的是努力分辨真假图像,练就火眼金睛。最终直到鉴别网络分辨不出生成网络生成的以假乱真的图像为止。

直接上一张我认为可以很好帮助理解的图片,看不懂没关系,我慢慢讲解,你肯定能看懂!

DCGAN生成对抗网络原理代码详细讲述(巨微)以及网络改进_第1张图片

 

  • 生成网络模型

DCGAN生成对抗网络原理代码详细讲述(巨微)以及网络改进_第2张图片

生成网络(Generator)接收一个随机噪声z,然后通过上采样(up-sampling)生成图像G(z)。上采样主要采用反卷积算法,G接收一个100-d随机噪声z,经过Project and reshape(实际上就是一个全连接层),转化为一个4*4*1024的feature map,然后经过多个反卷积层,生成大小为64*64*3的图像。(声明:官方给的生成网络只是为了帮助理解原理,并不是说DCGAN的生成网络就是一个反卷积网络,生成网络根据个人不同需求可以替换,随后本人将给出改进过后的生成网络)

  • 鉴别网络模型

 

DCGAN生成对抗网络原理代码详细讲述(巨微)以及网络改进_第3张图片

(声明:图形为FCN_VGG16的下采样过程,后面是我手画的,有点丑,但是不影响理解原理。卷积层的个数和全连接层个数没有按照下面的代码制作,具体层数见下面分析代码,重点是不影响理解。)

鉴别网络(Discriminator)的输入为一张图片,经过下采样(down-sampling,卷积运算),再接全连接层处理,送入sigmoid函数,输出真假概率。

  • 细节注意

1、G,D网络不采取任何池化处理;

2、G,D网络每一层均使用批标准化处理(Batch-Normalization);

3、在G网络中,激活函数除了最后一层外,都是用Relu函数,最后一层使用tanh函数;

4、D网络中,激活函数均使用Leaky Relu函数。

(激活函数的讲解,有的博主写的很好,【传送门】https://blog.csdn.net/kangyi411/article/details/78969642)

  • 损失函数的解析

论文上的损失函数的定义比较难以理解,集成度有点高,我将它进行分开解读:

生成网络的损失函数:

DCGAN生成对抗网络原理代码详细讲述(巨微)以及网络改进_第4张图片

鉴别网络的损失函数:

DCGAN生成对抗网络原理代码详细讲述(巨微)以及网络改进_第5张图片

 D(x):真实图片的概率;G(z):G生成的假图像;D(G(z)):G生成假图像的概率。

G希望生成的图像以假乱真,所以希望D(G(z))更大,所以用1-D(G(z))的巧妙表达,这样损失值应越越好。

D希望它的鉴别能力超强,所以D(G(z))应该很小,D(x)应该很大,一大一小,变化不定,所用采用1-D(G(z))的巧妙表达,这样损失值应越越好。(引入1-D(G(z))的巧妙表达,只是为了损失的巧妙表达,根据损失的定义是推不出来的,巧妙。)

  • 训练网络(敲黑板,核心,重点注意0和1的变化,博弈开始)

值得一提的是,两个网络独立训练,只不过生成网络G的网络健壮性需要D网络的训练结果去鉴别。是不是很迷?

再次上图:

DCGAN生成对抗网络原理代码详细讲述(巨微)以及网络改进_第6张图片

【开始博弈】

鉴别网络D的火眼金睛练就之路

首先,我们只有真的样本集,连标签都没有。我们先用G网络生成假的样本,这样我们就有了真假样本集。我们人为的知道哪些是真,哪些是假,人为打标签真的是1,假的是0。就这样迭代训练,损失应该是真假损失两部分之和,损失反向传播,优化。

G网络的健壮性自己无法判断,所以要借助D网络鉴别。这时重点来了,人为设定此时的假图像是1。为什么要这样呢?是不是很迷?是不是?当时我也想了很长时间。G网络实际坏的很,为了骗D网络,G网络就硬说自己生成的就是真图像,这时D网络就鉴别假图像,如果D网络是真孙悟空,那么G网络再怎么骗,没用,你说你是真的就是真的?我D网络说你就是假的。这时G不服,我再去训练,我再修炼,拿着更真的假图像去骗D,D再次进行鉴别,如果D发现自己的判断能力不足,D就再次回炉训练自己的火眼金睛,增强自己的鉴别能力。就这样互相伤害,互相增强,最终达到:G生成的图片就算D拿出吃奶的劲都没法鉴别真假,那就不0也不1,就算你0.5好了(理想情况下,实际训练一般达不到,只能非常非常趋近0.5)。

注意,在训练G网络,实际上是一个G-D的过程,但是后半部分的D(G(z))不参与训练,只训练生成网络G,我们用的是上一轮训练好的D网络来鉴别这一轮的G,所以无需再此训练。所以本人先运行G网络,再运行D网络并训练,再训练G,也可以根据情况适当增加参数更新次数(我是D训练2次,G训练1次)。

趁热打铁看代码:

d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,labels=tf.ones_like(d_logits_real) * (1 - smooth))) 

d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,labels=tf.zeros_like(d_logits_real)))

d_loss = d_loss_real + d_loss_fake 

g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,labels=tf.ones_like(d_logits_fake))
  • 改进的DCGAN网络

本人之前学习的是FCN,自然我要引入FCN当做生成网络。先上设计图:

DCGAN生成对抗网络原理代码详细讲述(巨微)以及网络改进_第7张图片

DCGAN生成对抗网络原理代码详细讲述(巨微)以及网络改进_第8张图片

我有真的样本集和标签集,我这样处理D、G的输入,我让真的样本集做生成网络的输入,生成特征图(分割图);让G的输出特征图作为假的样本集合标签集送入D网。

注明  DCGAN的D和G网络代码请关注下一篇博客。

                                                                                                                                      (河南理工大学---李林祥)

 

你可能感兴趣的:(生成对抗网络GAN)