【深度学习】生成式对抗网络的损失函数的理解

生成式对抗网络即GAN由生成器和判别器组成。原论文中,关于生成器和判别器的损失函数是写成以下形式:
【深度学习】生成式对抗网络的损失函数的理解_第1张图片
首先,第一个式子我们不看梯度符号的话即为判别器的损失函数,logD(xi)为判别器将真实数据判定为真实数据的概率,log(1-D(G(zi)))为判别器将生成器生成的虚假数据判定为真实数据的对立面即将虚假数据仍判定为虚假数据的概率。判别器就相当于警察,在鉴别真伪时,必须要保证鉴别的结果真的就是真的假的就是假的,所以判别器的总损失即为二者之和,应当最大化该损失。由于判别器(警察)鉴别真伪的能力随着训练次数的增加越来越高,生成器就要与之“对抗”,生成器就要相应地提高“造假”技术,来迷惑判别器。第二个式子为第一个式子的第二项,含义相同,只不过对于生成器应当最小化该项,生成器当然希望辨别器将虚假数据仍判定为虚假数据的概率越低越好,即将虚假数据误判定为真实数据的概率越大越好,即最大化log(D(G(zi)))损失函数。所以二者相互提高或者减小自身的损失,以不断互相对抗。
【深度学习】生成式对抗网络的损失函数的理解_第2张图片
我用pytorch搭建了一个简易的GAN,没用卷积层,只是单纯的全连接层,利用mnist图像作为真实数据,随机生成100维的随机噪声作为生成器的输入,20次迭代的最终结果如上图,可以看出GAN多多少少能有些真实图像的大概轮廓。

你可能感兴趣的:(深度学习)