基于对比学习的不对称图像转换 http://arxiv.org/abs/2007.15651
作者:Taesung Park, Alexei A. Efros, Richard Zhang ,Jun-Yan Zhu
开源代码:https://github.com/taesungp/contrastive-unpaired-translation
在图像转换(image-to-image translation)的任务中,我们想要的是在保留输入图像的结构特征的基础上,加入目标域的外观特征。一个经典的任务就是把马转换成斑马,在保留输入的马的图像结构的同时,将纹路换成目标域(斑马)的纹路。目前主流的做法基本上都是基于CycleGAN方法的变种,利用对抗损失(adversarial loss)强化目标域的外观特征,使用循环一致性损失(cycle-consistency loss)来保证原始输入图像的结构不变。但是CycleGAN的假设非常严格,要求输入的图像域和目标域之间存在双射关系,这一点在其实是很难满足的。所以这篇论文提出了一个替代性方案,通过最大化输入输出图像块的互信息(mutual information),使用一个对比损失函数,infoNCE loss, 来学会一个encoder将对应的图像块之间相互联系起来,与其他的图像块分离,如此一来encoder可以专注于两个域之间共性的部分如形状,而忽略两个域之间的差异性部分如纹理。CUT这篇论文证明了以多层次,图像块的范式运用对比学习技术的有效性,并且发现从单张图像本身中提取负性图像块的效果要好于从整个数据集中其他的图像中提取,因此甚至可以在单张图像上实现图像转换。
如下图中,使用多层图像块的对比损失,最大化相对应的多层图像块之间的互信息,这样将生成器和Encoder相结合,取得对应输入图像的生成图像。
对称图像转换(pix2pix),使用对抗损失和重建损失形成输入和输出图像之间的映射。在非对称图像转换中,没有目标域的对应样本,循环一致损失成为事实上的标准做法(CycleGAN),通过学习一个从目标域到输入图像的映射,来检查是否输入图像被正确映射到了目标域。之后的做法大多是在循环一致损失的基础上完成的(如UNIT,MUNIT),在这个领域,循环一致损失主要在三个层面上使用,图像与图像之间,隐空间到图像,图像到隐空间。但是这些都基于输入域和目标域之间存在双射关系的严格假设,这一点当某个域的图像由相较于另个域更多的信息时就更难获得很好的效果。
为了避开双射的限制,一个替代的想法是输入图像中存在的关系,类似地也应该在生成的图像中体现,就比如同一张图内近似的图像块,在生成的图像中也应该有这样近似的图像块。TraVeLGAN, DistanceGAN and GcGAN通过预定义的距离函数保证共享相似的内容,或是使用triplet loss保存输入图像之间的向量计算,再或是计算输入图像之间的距离和生成图像的距离使之保持一致等等做法,绕开循环一致性损失的限制。但是这些方法要么是需要预定义一个距离函数,要么保存的关系是基于整个图像的。CUT的做法是通过最大化互信息的方法,学习一个输入输出图像块之间的相似性函数,避免了以上的方法的缺陷。
大多是图像转换工作都是使用的逐像素重建进行度量,这无法反映人类的感知习惯并且会导致生成图片非常模糊。因此可以定义一个高维信号的感知距离函数,这一点使用在ImageNet上预训练的VGG分类网络就可以实现 ,并且其在人类感知测试中取得了超过传统度量方法(SSIM and FSIM)的效果。但是这个方法没法适应其他的数据集,并且它也不是一个基于图像对的相似性度量。CUT以互信息作为约束,将图像本身中的负样本利用起来,可以适用于不同特定的输入输出域,从而避免了对相似性函数的预定义。
传统的无监督学习需要预先设计好的损失函数来衡量预测表现,新的方法通过最大化互信息绕开这个问题,使用噪声对比估计(noise contrastive estimation,NCE)来学习一个Encoder,将关联的信号拉近,并与数据集中的其他样本形成对比。信号可以是图像本身,也可以是下采样特征,相邻图像块等等。CUT首先将infoNCE loss应用到了条件图像生成领域。
首先要定义图像转换问题,图像输入域为 X ∈ R H × W × C \mathcal{X}\in\mathbb{R}^{H\times W\times C} X∈RH×W×C,而输出图像域为 Y ∈ R H × W × 3 \mathcal{Y}\in\mathbb{R}^{H\times W\times 3} Y∈RH×W×3,数据集为 X = { x ∈ X } X=\{x \in \mathcal{X}\} X={x∈X}, Y = { y ∈ Y } Y=\{y \in \mathcal{Y}\} Y={y∈Y}, 其中在CUT的方法中数据集可以只包含单张图像。
在CUT方法中,生成器被 G G G分解为两个部分, 先是一个Encoder再是一个decoder,这样生成输出图像 y ^ \hat y y^的过程变成了, y ^ = G ( z ) = G d e c ( G e n c ( x ) ) \hat y=G(z)=G_{dec}(G_{enc(x)}) y^=G(z)=Gdec(Genc(x)).
在GAN的图像生成部分,CUT仍然是使用GAN的对抗损失,来保证生成的图像能和目标域的图像尽可能相似,这部分的损失就是:
L ( G , D , X , Y ) = E y ∼ Y log D ( y ) + E x ∼ X log ( 1 − D ( G ( x ) ) ) \mathcal{L}(G,D,X,Y)=\mathbb{E}_{y\sim Y}\log D(y)+\mathbb{E}_{x\sim X}\log(1-D(G(x))) L(G,D,X,Y)=Ey∼YlogD(y)+Ex∼Xlog(1−D(G(x)))
在互信息最大化方面,采用noise contrastive estimation(NCE)框架。对比学习的问题有三个信号组成,query和正样本,负样本,要做的就是让query和正样本信号相关联和负样本形成对比。将query和正样本, 以及N个负样本,分别映射成K维向量 v , v + ∈ R K , v − ∈ R N × K v,\ v^+\in \mathbb{R}^K,\ v^-\in \mathbb{R}^{N\times K} v, v+∈RK, v−∈RN×K,并用 v n − ∈ R K v^-_n\in \mathbb{R}^K vn−∈RK 表示第n个负样本,将这些样本归一化至单位球中,防止空间扩张或坍缩。这样就形成了一个N+1的分类问题,交叉熵损失计算如下其中 τ \tau τ是比例超参,表示正样本被选中的概率。
ℓ ( v , v + , v − ) = − log [ exp ( v ⋅ v + / τ ) exp ( v ⋅ v + / τ ) + ∑ n = 1 N exp ( v ⋅ v − / τ ) ] \ell(v,v^+,v^-)=-\log\left[\frac{\exp(v\cdot v^+/\tau)}{\exp(v\cdot v^+/\tau)+\sum^N_{n=1}\exp(v\cdot v^-/\tau)}\right] ℓ(v,v+,v−)=−log[exp(v⋅v+/τ)+∑n=1Nexp(v⋅v−/τ)exp(v⋅v+/τ)]
无监督学习中用到对比学习,既有图像层次也有图像块层次,具体到CUT要解决的任务中,对于输入输出图像不仅整个图像应该有着同样的结构,对应的图像块之间也应该有相应的结构 。所以应该用多层次图像块(multilayer patch-based)的学习目标。通过Encoder G e n c G_{enc} Genc 编码特征层,其中不同层不同空间位置代表了不同的图像块,层数越深图像块越大。CUT选择了 L L L 层特征图,将其通过2层 M L P MLP MLP 网络 H l H_l Hl产生了一系列的特征 { z l } L = { H l ( G e n c l ( x ) } L \{z_l\}_L=\{H_l(G^l_{enc}(x)\}_L {zl}L={Hl(Gencl(x)}L ,其中 G e n c l G^l_{enc} Gencl 表示第 l l l 层输出特征。序列 l ∈ { 1 , 2 , 3 , . . . , L } , s ∈ { 1 , 2 , . . . , S l } l\in\{1,2,3,...,L\},\ s\in\{1,2,...,S_l\} l∈{1,2,3,...,L}, s∈{1,2,...,Sl}, 其中 S l S_l Sl 表示第 l l l 层有 S l S_l Sl 个空间位置。将对应特征记为 z l s ∈ R C l z^s_l\in\mathbb{R}^{C_l} zls∈RCl 其他特征标记为 z l S \ s ∈ R ( S l − 1 ) × C l z^{S\backslash s}_l\in\mathbb{R}^{(S_l-1)\times{C_l}} zlS\s∈R(Sl−1)×Cl ,其中KaTeX parse error: Undefined control sequence: \C at position 1: \̲C̲_l 是每层的通道数。 同样的将输出图像 y ^ \hat y y^ 也编码成 { z ^ l } L = { H l ( G e n c l ( x ) } L \{\hat z_l\}_L=\{H_l(G^l_{enc}(x)\}_L {z^l}L={Hl(Gencl(x)}L
CUT的目标是将输入输出对应位置的图像块进行匹配,同一张图像其他位置的图像块作为负样本,将损失记做PatchNCE loss:
L P a t c h N C E ( G , H , X ) = E x ∼ X ∑ l = 1 L ∑ s = 1 S l ℓ ( z ^ l s , z l s , z l S \ s ) \mathcal{L}_{PatchNCE}(G,H,X)=\mathbb{E}_{x\sim X}\sum^L_{l=1}\sum_{s=1}^{S_l}\ell(\hat z^s_l,z^s_l,z^{S\backslash s}_l) LPatchNCE(G,H,X)=Ex∼Xl=1∑Ls=1∑Slℓ(z^ls,zls,zlS\s)
当然也可以从数据集的其他图像中提取图像块做负样本记做 z ~ \tilde z z~,可以像MOCO一样用一个辅助的移动平均Encoder H ^ l \hat H_l H^l 和移动平均 M L P H ^ MLP\ \hat H MLP H^ 共同计算,维护一个负样本字典 Z − Z^- Z−。
L e x t e r n a l ( G , H , X ) = E x ∼ X , z ~ ∼ Z − ∑ l = 1 L ∑ s = 1 S l ℓ ( z ^ l s , z l s , z ~ l ) \mathcal{L}_{external}(G,H,X)=\mathbb{E}_{x\sim X,\tilde z\sim Z^-}\sum^L_{l=1}\sum_{s=1}^{S_l}\ell(\hat z^s_l,z^s_l,\tilde z_l) Lexternal(G,H,X)=Ex∼X,z~∼Z−l=1∑Ls=1∑Slℓ(z^ls,zls,z~l)
最终的目标函数,和CycleGAN一样也添加了一致损失(identity loss) L P a t c h N C E ( G , H , Y ) \mathcal{L}_{PatchNCE}(G,H,Y) LPatchNCE(G,H,Y),以使 E y ∼ Y ∥ G ( y ) − y ∥ 1 \mathbb{E}_{y\sim Y}\|G(y)-y\|_1 Ey∼Y∥G(y)−y∥1 尽量小避免生成器对产生的图片造成不必要的变化。所以总损失包含对抗损失,对比损失,一致损失三个部分。
L ( G , D , X , Y ) + λ X L P a t c h N C E ( G , H , X ) + λ Y L P a t c h N C E ( G , H , Y ) \mathcal{L}(G,D,X,Y)+\lambda_X\mathcal{L}_{PatchNCE}(G,H,X)+\lambda_Y\mathcal{L}_{PatchNCE}(G,H,Y) L(G,D,X,Y)+λXLPatchNCE(G,H,X)+λYLPatchNCE(G,H,Y)
当使用 λ X = 1 , λ Y = 1 \lambda_X=1,\ \lambda_Y =1 λX=1, λY=1 联合训练时称为CUT,当取 λ Y = 0 \lambda_Y=0 λY=0 时,作为补偿取 λ X = 10 \lambda_X=10 λX=10 时称为FastCUT, 可以被看做是更快更轻量级的CycleGAN。可以看出CUT所采用的损失函数组成部分不多,要求的超参也不多。
L ( G , D , X , Y ) + λ X L P a t c h N C E ( G , H , X ) + λ Y L P a t c h N C E ( G , H , Y ) \mathcal{L}(G,D,X,Y)+\lambda_X\mathcal{L}_{PatchNCE}(G,H,X)+\lambda_Y\mathcal{L}_{PatchNCE}(G,H,Y) L(G,D,X,Y)+λXLPatchNCE(G,H,X)+λYLPatchNCE(G,H,Y)
当使用 λ X = 1 , λ Y = 1 \lambda_X=1,\ \lambda_Y =1 λX=1, λY=1 联合训练时称为CUT,当取 λ Y = 0 \lambda_Y=0 λY=0 时,作为补偿取 λ X = 10 \lambda_X=10 λX=10 时称为FastCUT, 可以被看做是更快更轻量级的CycleGAN。可以看出CUT所采用的损失函数组成部分不多,要求的超参也不多。
在实验中,CUT从图像质量和发现对应关系的能力上进行评价,图像质量通过 F I D FID FID (Fr´echet Inception Distance) 进行度量,而后者使用生成器Encoder的第一个残差块,可视化输出对应图像块特征的相似性,并通过PCA可视化主成分,验证了Encoder学到了相似性函数。
综合来看,CUT这篇论文相较于其他非对称图像转换的论文,主要的创新点还是在于引入了对比学习的概念,将CycleGAN的循环一致性损失改换成对比损失,放松了对图像域要求存在双射关系的假设,因此可以用在单向的图像转换任务中去,并且在结构上更加轻量级,避免了CycleGAN额外的生成器和判别器,减少了计算花费。从消融实验中可以看出,CUT的最关键核心点在于基于最大化互信息的,使用输入图像本身的图像块,以及要使用多层Encoder获得不同层级的。
pix2pix:Image-to-Image Translation with Conditional Adversarial Networks
CycleGAN:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
MOCO:Momentum Contrast for Unsupervised Visual Representation Learning
infoNCE:Representation learning with contrastive predictive coding