【cycleGAN代码学习笔记】

原文地址:cycleGAN模型构建及代码解读及细节_HNU_刘yuan的博客-CSDN博客_cyclegan代码

cycleGAN简介

cycleGAN是一种由Generative Adversarial Networks发展而来的一种无监督机器学习,是在pix2pix的基础上发展起来的,主要应用于非配对图片的图像生成和转换,可以实现风格的转换,比如把照片转换为油画风格,或者把照片的橘子转换为苹果、马与斑马之间的转换等。因为不需要成对的数据集就能够转换,所以在数据准备上会简单很多,十分具有应用前景。

cycleGAN名字中之所以有一个cycle,我觉得应该是原图经过一种生成网络转换后得到另一种风格的图片,然后还要经过另一种生成网络转换后尽可能的接近原图,形成了一个循环,所以被称为cycleGAN。

所以有:
AtoBtoA = G_BtoA(G_AtoB(real_A)) 从A风格转换到B风格,又转换为A风格
BtoAtoB = G_AtoB(G_BtoA(real_B)) 从B风格转换到A风格,又转换为B风格

cycleGAN中的网络

cycleGAN由两个生成网络和两个判别网络构成

  1. G_AtoB() 看作是风格A向风格B的生成网络
  2. G_BtoA() 看作是风格B向风格A的生成网络
  3. dis_A() 看作是判别输入图片是否属于风格A的判别网络                                                       
  4. dis_B() 看作是判别输入图片是否属于风格B的判别网络
  5. AtoB = G_AtoB(real_A) 看作是real_A经过生成网络转换得到的风格B的照片
  6. BtoA = G_BtoA(real_B) 看作是real_B经过生成网络转换得到的风格A的照片

其中G_AtoB()和G_BtoA()的输入为[B, C, W, H],即batchsize, channels, width, height,输出一般与输入相同; 

其中dis_A()和dis_B()的输入为[B, C, W, H],即batchsize, channels, width, height,输出的维度是[B, 1],里面的是经过sigmoid函数输出的,所以取值范围在[0, 1]进行分类。

 生成器由三个部分组成:

  1. 编码器(由三层卷积网络构成,并进行归一化;使用了残差块,减弱梯度消失,使网络可以自己自适应地调节层数的深浅,变得更深的同时更平滑
  2. 转换器
  3. 解码器(用到反卷积(逆卷积)和卷积层,经过残差结构,第一、二层反卷积,第三层卷积

辨别器:用的是5层卷积,将通道数减为1,最后进行池化平均,再reshape成[batchsize 1]

损失函数:
cycleGAN中用到了两种损失函数,

  1. MSE,应用在标签中,用来判断discriminator输出的label和真实lable之间的loss。        

    gen_AtoB中,Dis_B判断AtoB生成的图片与真实标签之间的loss
    gen_BtoA中,Dis_A判断BtoA生成的图片与真实标签之间的loss
    Dis_A中 real_A与真实标签之间的loss | | B2A与虚假标签之间的loss
    Dis_B中 real_B与真实标签之间的loss | | A2B与虚假标签之间的loss

  2. L1,应用在图片中,衡量图片与图片之间的loss。                                                                  real_A和 A2B2A之间
    real_B和 B2A2B之间
    real_A和 B2A(real_A)
    real_B和 A2B(real_B)

其中的第三种和第四种情况可以理解为:经过生成该图片风格的生成器生成的图片应该尽量与原图保持一致。也被成为identity loss。可以理解成生成器Gen_AtoB负责x域(domain)到y域图像的生成,如果输入y域的图片,输出仍然是y域的图片,比较符合直觉,用的是L1函数。

你可能感兴趣的:(GAN,学习,人工智能,计算机视觉,GAN,生成对抗网络)