用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析

韩国游戏公司NCSOFT最近开源了本算法的代码。

这篇论文的全名为《U-GAT-IT: UNSUPERVISED GENERATIVE ATTENTIONAL NETWORKS WITH ADAPTIVE LAYERINSTANCE NORMALIZATION FOR IMAGE-TO-IMAGE TRANSLATION》,这个算法做了一件非常有趣的事,把输入的真实人脸头像转换为二次元风格。

这是TensorFlow版本,曾经登上趋势榜第一 (现在变成了第三):https://github.com/taki0112/UGATIT

这是PyTorch版本github地址:https://github.com/znxlwm/UGATIT-pytorch

这是论文:https://arxiv.org/abs/1907.10830

用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析_第1张图片

当然此算法并不是只能用在做人脸的风格迁移上,所有不同域之间的风格迁移都是可以做的,比如马转斑马。那么我们就来详细看看这篇论文吧。

这篇论文主要提出的三个创新点:

1、提出了一种新的无监督的(不需要成对数据)算法,带有一个注意力模块和一个新的标准化方法(作者命名为AdaLIN)。

2、其中这个注意力模块带有一个辅助的分类器,帮助模型更好地将源域迁移到目标域。

3、AdaLIN方法帮助模型灵活控制图片的形状和纹理,而不需要修改网络结构和超参数。

 

我们先来看一下完整的网络结构:

用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析_第2张图片

上面是生成器,下面是判别器。

一、先看生成器:

首先学习两个概念:

global average pooling:将某一个卷积层的特征图进行整张图的一个均值池化,形成一个特征点,将这些特征点组成特征向量。举个例子,10个6*6的特征图,global average pooling是将每一张特征图计算所有像素点的均值,输出一个数据值,这样10 个特征图就会输出10个数据点,将这些数据点组成一个1*10的向量的话,就成为一个特征向量,就可以送入到softmax的分类中计算了。

CAM:https://blog.csdn.net/qq_30159015/article/details/79765520

生成器的目的就是为了将source图片,转换成target图片。source图片进来后,先经过一个降采样,之后经过encoder,得到此时的特征为Es,设Es有n个feature map(核)。将Es进行global average pooling处理,得到一个n维的向量,送入辅助的分类器ηs(分类器用于分类source和target)学习权重w,则w也为n维(作者在此受启发于CAM)。

公式为:  其中k为第k个feature map,i j为激活值的位置,δ为sigmoid激活函数。

此时我们就得到了w,利用w和Es,我们可以计算出a,公式为:

得到a后,再将a做一个AdaLIN的标准化处理,这个AdaLIN由作者提出,也是本文的核心创新点之一。

用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析_第3张图片

其中μI和μL、δI和δL分别为channel-wise和layer-wise的均值和方差。什么叫channel-wise和layer-wise呢?顾名思义,channelwise就是对每一个feature map做处理,而layerwise则是对每一层去计算均值方差。其中y和β由全连接层生成,τ是学习率,Δp是由网络优化器得到的梯度,p是限制到[0,1]之间的值。因此,当p接近1的时候,instance normalization更重要,当p接近0的时候,LN更重要。作者在residual blocks初始化p为1,在up-sampling blocks初始化p为0,这两个blocks的位置请看图。

当然了,这么说的话可能有些同学分不清LN、IN这些标准化概念,具体可以看看我的另一篇博客:https://blog.csdn.net/wenqiwenqi123/article/details/105073639

在经过AdaLIN后,再进行上采样,得到生成的假的target图。

 

二、再看判别器

判别器的大致原理与生成器差不多,不同的是ηDt(辅助分类器)和Dt(x)(整个判别器)的训练目的都是为了区分输入的图片究竟是真实的还是生成器生成的。整体流程如下:

将真实的target图和生成器生成的假target图输入网络,经过encoder,得到E。再做一个global average pooling,送入辅助分类器ηDt,学习得到权重w。根据w和E,生成a。再用a经过sigmoid激活函数,作二分类。

 

三、损失函数:

本算法有四个损失函数:

1、对抗损失

这个损失自然是每一种GAN算法都有的,作者采用了Least Squares GAN的对抗损失:

2、循环损失:

这个损失可以参考下我的另一篇博客,cycleGAN的经典损失:https://blog.csdn.net/wenqiwenqi123/article/details/105123491

3、一致损失:

为了让输入的图和输出的图的分布相似,需要用一致损失来约束。也就是说输入一张target图,经过s->t的生成器,这张图不应该有太大变化。

4、CAM损失:

辅助分类器的分类损失:

用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析_第4张图片

公式五是生成器的η,目的是区分source和target的图片。 公式六是判别器的η,目的是区分target和生成器生成的假target图。

因此,loss函数的总公式为:

 

四、实验

作者还是做了相当充分的实验的,各种消融实验,具体实验是做什么的请看图下方英文吧:

用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析_第5张图片

 

用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析_第6张图片

用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析_第7张图片

值得一提的是作者还做了定量和定性实验:

定性实验由135个人,主观判断:

用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析_第8张图片

 

定量试验他们使用了最近提出的KID,KID计算了生成的图和真实图的由inception网络提取出的特征表示的MMD距离。KID越小代表算法越好。

用小姐姐自拍,生成二次元萌妹子——《U-GAT-IT》人脸动漫化论文解析_第9张图片

 

你可能感兴趣的:(深度学习)