lesson7 part3 GAN

回到图像修复 [52:01]

lesson7-superres-gan.ipynb

要创建可以把差图片变成好图片的模型,我们需要有包含好图片和差图片的数据集。最简单的方式,就是找到好图片,再把它们变差。


lesson7 part3 GAN_第1张图片
image.png

把它们变差的方式是创建一个叫crappify()的函数,里面包含了你把图片变差的逻辑。我的是这样的,你可以自己写一个:

打开好图片

  • 调整大小,变成小的96x96分辨率,使用双线性插值
  • 然后取一个10到70随机数
  • 把这个数画到图片随机的位置
  • 然后用这个随机数作为JPEG画质保存图片
  • 如果JPEG画质是10,图片看起来就是垃圾,如果是质量数是70则说明质量不错。如


    lesson7 part3 GAN_第2张图片
    image

    可以看下这个(最后一行),这里有个数字。它是翻转变形后的。你不会总是看到这样的数字,这是我们随机添加的,很多时候,它是没有的。

我们要演示怎样把这样有文字的、画质非常差的、故意制作的图片变成右边的这样高质量的图。我使用是牛津的宠物数据集。我们在第一课用的那个。其它的画质都不如这些猫和狗的图片。

parallel并行 [53:48]

crappify()函数的处理可能需要花挺长时间,但fastai有一个叫parallel的函数,如果你给parallel()传入一个函数名,然后还有一连串东西来运行这个函数,这样可以运行的非常快。
lesson7 part3 GAN_第3张图片
image.png

这次作业里一件有意思的事就是写这个函数,试一试想想怎么写一个有趣的残次化函数,可以用来做你想做的事。如果你想为黑白图片着色,你可以用残次化函数来把图像变成黑白的。如果你想对图像做大块地切割,并把它们替换成幻觉图片,你可以在这里加一个黑箱。
lesson7 part3 GAN_第4张图片
image.png

如果你想处理旧家庭照片扫描图,那种已经折起来,起皱了的图片。你可以尝试找到方法,来给图片增加落灰,起皱了的折痕,等等。任何你没有添加在残次化函数里的东西,你的模型不会去学习怎样修复它。因为每次它在你的输入和输出照片上看到的都是一样的, 所以它认为这不是它需要修复的东西。

现在我们想创建这样一个模型,可以输入左边这样的照片,输出右边这样的照片。
lesson7 part3 GAN_第5张图片
image

显然,我们要用U-Net,因为我们已经知道U-Net可以做这样的事。我们把数据传到U-Net里。

lesson7 part3 GAN_第6张图片
image.png

我们的data就是这两个文件夹里的这些图片的文件名,

lesson7 part3 GAN_第7张图片
image.png

做一些数据转化,数据堆归一化,或者使用imagenet_stats。因为我们将用预处理模型。为什么我们用预训练模型?因为如果我们想去掉这个46,你需要知道这可能是什么,你需要知道这是一个什么的图片。不然,怎样知道它原本应该是什么样子?所以我们要用预训练模型,它知道如何处理这些东西。

我们用这些数据创建U-Net,架构是ResNet34。这三个东西(blur, norm_type, self_attention)是重要的、有意义的、有用的,
lesson7 part3 GAN_第8张图片
image.png

但我要把它们放在课程第二部分讲,如果你要用U-net解决这样的问题的话,这三者要一直保留。

这部分东西我叫它生成器,它将生成模型。这不是它们的正式定义,但基本意思是这样的。我们输出的是实物,这个例子里就是一个图像,而不只是数字。我们将生成一个生成学习器(generator learner),也就是这个U-Net learner,然后我们可以fit,我们使用了Mse损失函数,即均方差,也就是实际像素值和我们的预测的像素值间的均方差。MSE 损失通常要两个向量。这里,我们有两个图片,我们有一个MSE Loss Flat版本(MSC扁平损失),也就是将那些图像扁平化成一个长向量。’没有理由不这样做,即使你只有一个向量,它也可以用,如果你没有向量,它也会工作得很好。

lesson7 part3 GAN_第9张图片
image.png

所以关于像素值,均方误差已经下降到了0.05,用了1分35秒。这个结果还可以接受。像fastai里的所有东西一样表现不错。因为我们默认做迁移学习,当你创建这个时,

image.png

它会冻结预训练部分。U-Net的预处理部分是这个下采样部分。

lesson7 part3 GAN_第10张图片
image.png

就是用了ResNet在的地方。

lesson7 part3 GAN_第11张图片
image.png

所以我给这部分解冻,再训练一下,看看结果!用4分钟的训练,我们得到了可以完美去除数字的模型。


lesson7 part3 GAN_第12张图片
image

它的上采样做得不好,但也挺好了。有时它在去除数字时,它可能会留下一点残迹。但也已经做得很好的了。所以如果我们想做的是,去除水印,这已经结束了。

但是我们还没有完成,因为我们想让中间的图更像右边的图。怎样做到呢?我们没有做到这个,是因为我们的损失函数没有体现我们想要的是什么样的。因为实际上,中间图片和右边图片的像素均方差非常小。大多数像素都非常接近正确的颜色了,但我们还缺了枕头的纹理,以及几乎整个眼球的部分。我们还缺少了毛发质感的部分。所以我们想要一些能比MSE效果更好的损失函数。比如说,这是不是一个优质图片。

生成对抗网络(Generative Adversarial Network) [59:23]

这有一个比较普遍的,回答这个问题的方法。它叫做生成对抗网络(Generative Adversarial Network),也就是GAN。GAN用一个损失函数来解决这个问题,这个函数实际上调用了另外一个模型。


lesson7 part3 GAN_第13张图片
image

我们拿到了差的图片(crappy image),而我们已经建立了一个生成器,然后它产生了这样的预测(中间图片)。我们把一个高分辨率的照片(右边图片),和用基于像素的均方误差指标生成的预测来比较。

我们也可以训练另外一个模型,我们可以叫它discriminator(辨别器)或者critic(鉴别器),这都是相同的东西。我会叫它critic。我们构建一个二分类模型,比较所有生成图像和相应的真实的高分辨率图像,并学习分类。区分生成和真实的图像。当你看一些图片,然后问“嘿,你怎么看,这是一个高分辨率的猫还是一个生成的猫?这个呢?是一个高分辨率的猫还是一个生成的猫?”这就是一个普通的标准二分类交叉熵分类器。我们已经知道怎样做这个二分类分类器了。如果我们有这两者之一(Discriminator/Critic),我们可以训练微调生成器,


lesson7 part3 GAN_第14张图片
image.png

不再使用像素均方差做损失度,损失度可以是我们骗过判别器的程度作为损失,能否生成出判别器认为是真实的图片。

这是一个很好的方案,因为如果做到了这个,如果损失函数是“我能不能骗过critic”,它会学习生成critic分辨不出真假的图片。我们可以像这样训练几个批次。但是critic表现并不是那么好,因为要分辨它们不是很难。这些图片太差了,所以要分辨出来很简单。所以,在用这个critic做损失函数,训练了一段时间后,这个生成器变得很擅长骗过critic。现在我们要停止训练generator,我们要用这些新生成的图片再训练critic。现在这个generator更好了,对critic来说,要判断哪个是真的哪个是假的变得更难了。我们要多训练一下判别器。一旦完成,这个critic现在又很擅长识别出生成的图片和原始图片的区别,我们要回过头来,再用这个更好的critic作为损失函数微调generator。

我们这样来回做。这就是一个GAN。这是我们的GAN版本。我不知道有没有人写过这个,我们做了一个新版本的GAN,它和原始的GAN很像,但是我们有个精妙的技巧,我们用了预训练的generator和critic。GAN经常上新闻。是很是时尚的工具,如果你见过它们,你可能听说过训练起来很难。但我们发现,最难的是在开始。如果你没有预训练的generator,也没有预训练的critic,那基本是盲人骑瞎马。generator要生成能骗过critic的东西,但critic什么都不知道,所以没有什么事可做。然后critic试着判断是生成的图片是不是真的,这很明显,所以它直接就做了决断。所以它们都没有进步。等到他们终于上道了,进展就很迅速了。

所以如果你找到一种无须GAN来生成图片的方法,比如基于像素的均方误差损失,不用GAN模型而去做判别,比如在第一代生成器上预测,你就可以有很大的进展。

创建一个critic/discriminator [1:04:04]

我们来创建critic。要创建一个完全标准的fastai分类模型,我们需要两个文件夹,一个放高分辨率的图片,一个放生成的图片。我们已经有了存高分辨率图片的文件夹,我们只需要保存生成的图片。

lesson7 part3 GAN_第15张图片
image.png

这是做这个的一小段代码。我们要创建一个叫image_gen的目录,赋值给path_gen的变量里。一个叫save_preds的小函数,它接收一个data loader。我们取出所有的文件名。对一个item list来说,如果它是一个image item list,.items里存的是文件名。这个data loader的dataset里是文件名。现在,我们看看data loader里的每一个batch的数据,我们取出一个Batch的预测,reconstruct=True代表它会为batch里的每一个东西创建fastai图片对象。我们遍历每个预测值,保存它们。我们用和原始文件一样的名字,但是会把它放到新目录里。

就是这样。这样保存预测结果。可以看到,我不仅用fastai里现成的东西,也给你们看怎样写自己的东西。通常,这不需要多少代码。如果你学习课程第二部分,里面很多地方就像这样里的一样,会教你怎样用fastai库里的东西,当然,这里是怎样写库里的代码。越来越多地,我们会学写自己的代码。

好了,保存了这些预测值。我们来在第一个文件上执行 PIL.Image.open,它显示在这里。这是我们生成的一个图片样例。


lesson7 part3 GAN_第16张图片
image.png

现在我按照常规训练critic。重启Jupyter notebook来释放内存很麻烦。如果你知道是什么占用了大量的GPU,你可以直接把它设成None,这是一个简单的方法。


image.png

我们对这个learner做了这样的处理,然后运行gc.collect,这会让Python回收在用的内存。做完之后,内存就正常了。你可以使用所有的GPU内存了。
如果你用nvidia-smi查看内存,你看不出它被释放了,因为PyTorch还占用着它们做缓存,但是这些内存已经可以用了。这样,你不用重启notebook了。
lesson7 part3 GAN_第17张图片
image.png

我们要创建一个critic。和以前一样,它只是一个普通文件夹中的image item list,这个classes是image_gen 和 images。我们要做一个随机的分割得到一个验证集,因为我们想知道critic在验证集上表现如何。我们像之前一样用label_from_folder,做一些tranform,data bunch,normalize。这样我们得到了一个标准的分类器。这是里面的一些图片:


lesson7 part3 GAN_第18张图片
image

这是一个真的图片、真的图片、生成图片、生成图片等等,我们要区分出真伪。

我们还是像以前一样要使用binary cross entropy。但是,这里我们不再使用ResNet。在课程第二部分,我们再详细讲原因。这里需要说的是当你再次回顾这一点时,需要特别小心,generator和cirtic不能都用相同的方向推进,比如让权重增长到失去控制。我们需要用一种叫spectral normalization的东西让GAN正常运行,我们会在课程第二部分学习这个。

不管怎样,如果你用gan_critic,fastai会给你一个适合GAN的二元分类器。我怀疑我们可以在这里用一个ResNet,我们要创建一个带spectral norm的预训练ResNet。


lesson7 part3 GAN_第19张图片
image.png

希望快点就能做到。我们会拭目以待。但是现在,最好的方法是用gan_critic。一个GAN critic在计算损失时,会用一种和取均值略有不同的方式,所以现在在做GAN时,你要用AdaptiveLoss封装你的损失函数。还是一样,你会在课程第二部分看到细节。现在,只要知道这是你要做的,它很管用。

除了那个有点奇怪的损失函数,和有点奇怪的架构之外,其它的东西都是一样的。我们可以称之为我们创建的判别器,因为我们用了略微不同的架构和损失函数,metric也略有不同。这是和GAN版本准确率等价的critics。然后我们可以训练它,你可以看到它区分差的图片和好图片的准确率是98%。也就是在辨认残次图片和好图片上的准确率。但确实我们没看到数字,因为这是生成图像, generator已经知道怎样去掉上面的数字。


lesson7 part3 GAN_第20张图片
image.png

完成 GAN [1:09:52]

让我们来结束这个游戏吧。我们已经预训练了生成器,现在需要开始乒乓乒乓...,每个模型都训练一点。每个模型上训练分配的时间以及应该使用的学习率,依然没有清晰明确的取值办法,更多依赖于经验。我们创建了一个GANLearner

image.png

需要传入你的generator和critic对象,我们在这里就直接加载了。
lesson7 part3 GAN_第21张图片
image.png
就是我们刚刚训练好的。
然后你运行learn.fit时,会自动计算出训练生成网络的时间以及何时切换去训练判别器。它会来回地切换。

image.png

关于这里的权重(weights_gen=(1.,50.))。我们不仅用critic做损失函数。如果我们只用critic做损失函数,GAN会很擅长创建看起来真实的图片,但它这些图片与原始图片毫无关系。其实我们把像素损失和critic损失加到一起,这两个损失的尺度不同,我们应把像素损失乘上一个50到300的数,这样一个范围通常是有效的。

另外,关于GAN的其他方面,GAN在训练中不应使用学习率动量优化算法。基于动量的训练不太可行。因为你不停在generator和critic间切换,很难用上动量训练。可能有使用动量的方式,但没见过有人这样做过。所以,当你创建Adam优化器时,这个动量的参数(betas=(0.,...)),你要设成0。

所以只要你用了GAN就使用这些超参数,应该就行的通,这就是GANLearner的功能。


image.png

你可以执行fit,它会训练一会儿。


lesson7 part3 GAN_第22张图片
image.png

lesson7 part3 GAN_第23张图片
image

GAN里一个难点是这些损失度没有意义。你不能期望随着generator变好这些值会下降,因为随着generator变好,对critic来说,任务越来越难,随着critic变好,对generator来说越来越难。这些损失值会保持不变。很难知道它们做得怎么样,这是训练GAN时一个困难的地方。要知道它们做得怎么样的唯一的方式是时不时地看看这些实际的结果。不过我没有,如果你在这里运行 show_img=True
image.png

它会在每个epoch后打印出一个样本。我们没有把它放到notebook里,这对repo来说太大了,但是你可以试试。我把结果放到了最后,就是这里。

lesson7 part3 GAN_第24张图片
image.png

lesson7 part3 GAN_第25张图片
image

应该说结果很漂亮。已经知道怎样去掉这些数字,但现在我们已经没法知道这里原本是什么。
lesson7 part3 GAN_第26张图片
image.png

它确实很漂亮地锐化了把这只小猫。但也有问题,这里有一些奇怪的噪声。确实比原来的糟糕的图片好多了。
lesson7 part3 GAN_第27张图片
image.png

要把左边的图变成高分辨率是一个困难的任务。但有一些明显的问题。像这里(第三行),这里应该是眼球,但它们没有。
lesson7 part3 GAN_第28张图片
image.png

为什么?因为我们的critic不知道眼球。即使它们知道,它也不知道眼球非常重要。我们关心眼睛。当我们看到一个没有眼睛的猫,它一点也不可爱。但判别器不知道这是一个重要的特征。特别是因为判别器不是一个预训练的网络,所以我有点怀疑如果用预训练的网络替代判别器,一个在imageNet上预训练的网络,同时也兼容GAN,这样的网络效果可能会更好。不过这种做法显然也会有缺点。课间休息后,我会演示怎样找到猫的眼球。

提问: 对什么样的问题,你不使用U-Net? [1:14:48]

U-Net用于你的输出的大小和输入的大小接近的时,而且某种程度上一致时,交叉连接是没有意义的。如果那个程度的空间分辨率对于输出,没有必要。所以任何一种生成模型,比如图像分割,就是一种生成模型,它生成的图片,是原本物体的遮罩(mask)。所以几乎任何你想要的输出的分辨率和保真度和输入相当,这种东西对于分类器没有意义。在分类器里面,你只想要下采样路径,因为最后你只想要一个数字,代表是猫还是狗或者其他宠物。

Wasserstein GAN [1:15:59]

在结束GAN之前,讲下这里有一个你们可能有兴趣读的notebook lesson7-wgan.ipynb。几年前GAN刚出现时,人们一般用它凭空创建一些图片,我认为这没有什么用处。但是我想这是一个好的研究练习。所以我们实现了这个WGAN paper,它第一次做出了一些成果,也不难,你会看到如何用fastai库实现它,这很有趣。因为我们使用的是LSUN_BEDROOMS数据集,我们放在了URL里。你尅看到很多卧室,很多很多卧室。使用的方法,就是Silva写的方法。我们在这个例子中使用的方法,只是想创造一个卧室,所以我们实际上做的是,给生成器的输入,不是处理过的图片,而是随机噪声。生成器的任务是,把随机噪声变成判别器无法区分输出的图片和真实的卧室,我们没有做任何预训练,或其他加速和简化的东西,就是一个很传统的方法。但可以看到,还是用了GANLearner,实际上是WGAN版本,一种老式的方法,只需要循规蹈矩输入数据给生成器和判别器,然后调用fit(),就可以看到图片显示出来了。在第一个epoch后,以及两三个轮次后,都没有很好的卧室图片。可以看到,在早期,这种GAN做不出什么大不了的东西,但是最终训练了几个小时后,出来了一些像卧室的图片。所以这是个你可以捣鼓的notebook,很有意思。

你可能感兴趣的:(lesson7 part3 GAN)