CycleGAN的介绍与实现效果

本篇文章将会介绍cyclegan的基本原理以及实现的效果。

1.实现效果

首先介绍一下cyclegan的实现的效果,简单来说就是将不同域之间的图像进行转换,而本身的形状保持不变如下图:

CycleGAN的介绍与实现效果_第1张图片

上图是论文给出的效果,可以看出从斑马变成马,本身的状态保持不变,只是改变了自身的颜色。与pix2pix不同的是,pix2pix的训练数据必须是配对的,所以类似于上面的图片很难找到配对的图片,而cyclegan的数据集只有两种,比如上面的数据集就是斑马和马即可。而且在最后实现的效果发现,cyclegan的效果会比pix2pix好一点。

2.网络结构

下图是我在github中找到的一张原理图:

CycleGAN的介绍与实现效果_第2张图片

首先cyclegan的网络有两个鉴别器和两个生成器,这与之前学到的gan会有所不同。两个生成器的作用是:

G_A2B:将真实马的图片变成相同状态的斑马图片(假的)或者将生成的假的马的图片变成斑马。

G_B2A:将真实的斑马图片变成相同形状的马的图片(假的)或者将生成的斑马图片变成马。

D_A:鉴别真实的马或者鉴别生成的马。

D_B: 鉴别真实的斑马或者鉴别生成的斑马。

在这里大家可能看不懂每一个鉴别器和判别器的作用,和网络是怎样实现斑马和马之间的转换,接下来看另一张图:

CycleGAN的介绍与实现效果_第3张图片

 这张图是论文中给出的,它的意思就是X(马)经过生成器G(G_A2B)来生成带有原来形状的斑马图片,紧接着我在用另一个生成器F(G_B2A)来将我刚刚生成的斑马图片还原成之前马的样子,最后两个鉴别器分别来判断生成的斑马和真实的马的真假。反过来原理也是一样,我将Y(斑马)放入生成器(G_B2A)中来生成假的带有原来形状的马,紧接着我在用另一个生成器(G_A2B)把生成出来的马来变成原来形状的斑马,紧接着两个鉴别器分别来判别生成的马和真实的斑马的真假。这就是cyclegan的原理。

CycleGAN的介绍与实现效果_第4张图片

 上图是网络生成器和判别器的网络结构,是在一个老师的课程上看到的,大家可以简单看看。

3.损失函数

损失函数是本篇文章最重要的部分,因为它的损失函数会在之后的很多gan中都会用到,那就是循环损失(Cycle Consistency Loss)。首先介绍一下它的作用。假如我在训练的时候例如马到斑马的训练中,我使用生成器来生成斑马的图片,如果我生成的斑马图片的形状跟原来马的形状相差很大,但是生成的斑马的图像又特别的逼真,那样就失去了意义,所以为了防止这种情况就诞生出了这个损失函数就是说上面的图(b)当马通过两个生成器后变回原来的马,将两个图片相减来计算两者的差距,如果两者距离约小,那就代表两个图片就越相似,这就是循环损失。

 

接着就是正常的生成对抗损失,这部分没有什么好讲的,再就是还有一个损失就是Identity Loss,这部分在论文中没有体现所以就不讲了,主要在特殊训练集中使用,比如艺术作品之类的。

CycleGAN的介绍与实现效果_第5张图片

CycleGAN的介绍与实现效果_第6张图片

 上图是两个生成器的损失函数组成部分。

 

4.代码

那么我们如何在代码中来计算各个损失呢,两个生成器分别计算损失,然后将损失加起来,之后放入一个adam中同时更新两个生成器,判别器也是同样的操作。

CycleGAN的介绍与实现效果_第7张图片

 5.实现效果

我一开始生成的是斑马和马

CycleGAN的介绍与实现效果_第8张图片CycleGAN的介绍与实现效果_第9张图片

 CycleGAN的介绍与实现效果_第10张图片CycleGAN的介绍与实现效果_第11张图片

可以看出来生成的效果很差,当时想要放弃了不想训练,后来老师让我在试一试,于是我索性就换了一个数据集,就是橙子和苹果的,但是我改了一些参数,就是循环损失的参数一开始我设置的是10,后来我感觉两个图片不太像所以我就改成了20。

CycleGAN的介绍与实现效果_第12张图片CycleGAN的介绍与实现效果_第13张图片CycleGAN的介绍与实现效果_第14张图片

 CycleGAN的介绍与实现效果_第15张图片CycleGAN的介绍与实现效果_第16张图片CycleGAN的介绍与实现效果_第17张图片

 上图顺序分别是真实图片,生成图片,和转换图片,可以看出效果还可以。一共训练了200轮跑了大概两天吧,记不太清了。这里给大家一个建议,主要看看循环损失的意义,在就是出了问题一定要找到原因,要不以后学起来可能会比较困难。如果有不会的地方、或者我写错的地方,可以私聊我一下,如果看见了,我可以回答一下。

 

 

 

 

 

 

 

你可能感兴趣的:(笔记,深度学习,pytorch,神经网络)