CVPR 2021 TransGAN: Two Transformers Can Make One Strong GAN

动机
  1. GANs的训练不稳定性。

    生成性对抗性网络(GANs)在众多任务中获得了相当大的成功,包括图像合成,图像翻译和图像编辑。然而,由于GANs的训练不稳定性,即使目前已经投入很多努力来稳定GAN训练,彻底解决这个问题,需要做进一步研究。

  2. 改进GAN的另一条平行途径是检查它们的神经结构。

    经过对GAN的大量研究,发现当充当(生成器)主干时,流行的神经架构在所考虑的数据集上表现得相当好。他们的消融研究表明,在ResNet家族结构中应用的大多数变化导致样本质量的微乎其微的改善。然而,之后有研究将神经结构搜索(NAS)引入到GANs中,并表明,就像对其他计算机视觉任务一样,增强的主干设计对于进一步改进GANs也同样重要,进而提出了比标准ResNet拓扑结构更强的GAN体系结构。

  3. 以往的研究大多是使用卷积神经网络(CNNs)作为GAN的骨干。

    最初的GAN使用的是全连接网络,只能生成小图像。DCGAN是第一个使用CNN架构放大GAN的方法,该架构允许更高分辨率和更深生成模型的稳定训练。此后,在计算机视觉领域,几乎每一个成功的GAN都依赖于基于CNN的生成器和判别器。卷积对自然图像具有很强的归纳偏好,对当前GAN所获得的吸引人的视觉结果和丰富的多样性起到了至关重要的作用。

  4. 能建立一个完全没有卷积的强GAN吗?

    这不仅是一个出于求知欲的问题,而且也是一个具有实际意义的问题。从根本上说,卷积算子有一个局部感受野,因此CNN不能处理长时依赖关系,除非通过足够多的层。然而,这可能会导致特征分辨率和精细细节的损失,以及优化的困难。因此,传统的基于CNN的模型(包括常规的GAN)本质上不太适合于捕捉输入图像的“全局”统计数据,这可以通过计算机视觉中采用自我注意力和非局部操作的益处得到证实。

  5. 最近人们对transformers的强烈兴趣表明,它们有可能成为计算机视觉任务(如分类、检测和分段)的强大“通用”模型。transformer已经在自然语言处理(NLP)中盛行,并且最近,在各种视觉基准中开始表现得与他们的CNN competitors相当甚至更好。对于计算机视觉,transformer的魅力至少表现在两个方面:

    (1) 它具有很强的表征能力,没有人为定义的归纳偏好。相比较而言,CNN表现出对特征局部性的强烈偏向,以及由于在所有位置上共享滤波器权重而具有空间不变性;

    (2) transformer架构是通用的,概念上是简单的,并且有可能成为跨任务和领域的强大的“通用”模型。它可以摆脱在基于CNN的模型中常见的许多ad-hoc构建块。

方法
简介

不同于以往仅应用自我注意力或transformer编码器块结合基于CNN的生成模型的工作,首次利用纯transformer结构构建一个完全无卷积的GAN。然而,由于以前计算机视觉中所有纯粹基于transformer的模型都专注于分类和检测等判别任务,提出的GAN面临着几个令人生畏的挑战。首先,尽管直接应用于图像块序列的纯transformer架构可以在图像分类任务上表现得非常好,但不清楚相同的方式在生成图像时是否仍然有效,这对结构、颜色和纹理的空间一致性提出了很高的要求。现有的少数几个输出图像的transformer都一致地利用了基于CNN的部分编码器或卷积特征提取器。而且,即使给定了设计良好的基于CNN的架构,训练GANs也是出了名的不稳定和容易出现模型崩溃。训练视觉transformer也是众所周知的单调冗长的、繁重和数据饥渴。将两者混为一谈无疑会放大训练的挑战。

鉴于这些挑战,提出了一系列的改进和创新,以建立纯transformer式GAN架构,称为TransGAN。它包括一个 memory友好的基于transformer的生成器(它在降低嵌入维数的同时逐步提高特征分辨率)以及一个也是基于transformer的patch-level判别器。一个简单的选择可以直接堆叠来自原始像素输入的多个transformer块,但这将需要昂贵的内存和计算。因此,TransGAN从一个内存友好的基于transformer的生成器开始,通过逐步提高特征映射分辨率,同时降低每个阶段的嵌入维数。判别器也是基于transformer的,它将图像块(而不是像素)tokenize为输入,并在真图像和假图像之间进行分类。这种vanilla的TransGAN结构通过自注意力自然地继承了全局感受野的优势,但实际上却导致了生成能力的退化和视觉平滑性的损坏。为了缩小基于CNN的GAN之间的性能差距,采用数据增强(比标准GAN更多),具有自监督辅助损失的多任务协同训练策略,以及强调自然图像邻域平滑性的局部初始化自注意力,来进一步优化TransGA。
CVPR 2021 TransGAN: Two Transformers Can Make One Strong GAN_第1张图片

TransGAN

选择transformer编码器作为基本块,并尽量做出最小的改变。编码器由两部分组成。第一部分由多头自注意力模块构成,第二部分是具有GELU非线性的前馈MLP组成。在这两个组成部分之前应用层归一化。两个组成部分均采用残差连接。

内存友好的生成器

NLP中的transformer是将每个单词作为输入。然而,如果类似地通过堆叠transformer编码器以逐像素的方式生成图像,即使是低分辨率的图像(例如32×32)也会导致长序列(1024),然后导致甚至更爆炸性的自注意力的代价(序列长度的二次方)。为了避免这种令人望而却步的成本,启发自基于CNN的GAN的共同设计理念,即在多个阶段迭代提升分辨率。采用的策略是逐步增加输入序列,降低嵌入维数。

提出了一个 memory友好的基于transformer的生成器,它由多个阶段组成(对CIFAR-10默认值为3)。每个阶段堆叠多个编码器块(默认为5、2和2)。分阶段,逐步提高特征图分辨率,直到它满足目标分辨率HT×WT。具体来说,生成器将随机噪声作为输入,并将其通过多层感知器(MLP)传递到长度为H×W×C的向量。向量被重新调整大小为H×W分辨率的特征映射(默认情况下H=W=8),每个点都有一个C维嵌入。该“特征映射”随后被处理为长度为64的C维token序列,并与可学习的位置编码相结合。

与BERT类似,transformer编码器将嵌入token作为输入,递归计算每个token之间的对应关系。为了合成更高分辨率的图像,在每个阶段之后插入一个上采样模块,由一个reshaping和pixelshuffle 模块组成。上采样模块首先将一维token嵌入序列重构为一个二维特征映射x0∈RH×W×C,然后采用pixelshuffle模块对其分辨率进行上采样,对嵌入维数进行下采样,得到输出x0∈R2H×2W×C/4。然后,将2D特征映射x0 reshape为嵌入token的1D序列,其中token数变为4HW,嵌入维数为C/4。因此,在每个阶段,分辨率(H,W)变大了2倍,而嵌入的维度C减小到输入的四分之一。这种权衡减少了内存和计算的增大。重复多个阶段,直到分辨率达到(HT,WT)为止,然后将嵌入维度投影为3,并得到RGB图像Y(维度为HT×WT×3)。

判别器的tokenized输入

与需要精确合成每个像素的生成器不同,判别器只需要区分真/假图像。这允许将输入图像语义化tokenize为更粗糙的patch-level。判别器将图像块作为输入。将输入图像Y∈RH×W×3分割成8×8的图像块,每个图像块可以看作一个“单词”。然后通过一个线性flatten层将8×8的图像块转化为token嵌入的1D序列,token数N=8×8=64,嵌入维数为C。之后,加入可学习的位置编码,并在1D序列的开始处附加一个[cls]标记。在通过transformer编码器后,分类头只获取[cls] token来输出真/假预测。

基于transformer的GAN的性能评估

为了提高基于transformer的生成器G和判别器D的性能,参考了一个当前最先进的GAN,即AutoGAN,它具有卷积G和D。研究了四种组合:i)AutoGAN G+AutoGAN D(即原始的AutoGAN);ii)Transformer G+AutoGAN D;iii)AUTOGAN G+Transformer D;iv)Transformer G+Transformer D(即vanilla TransGAN)。

transformer G在各阶段有{5,2,2}个编码器块,Transformer D只有一个阶段,该阶段有7个编码器块。对于所有模型,在CIFAR-10上对其进行训练,以评估Inception Score(IS)和FID。尽最大努力调整超参数以达到最佳性能。研究发现,Transformer G具有很强的能力:在与成熟的AutoGAN D进行训练时,其性能已经与原AutoGAN不相上下。在生成器中使用transformer是成功的。而Transformer D似乎是一个劣质的“竞争对手”,无法促进AutoGAN G生成更好的结果。用Transformer G代替AutoGAN G后,结果略有改善,更有利于G和D结构的对称性。然而,使用卷积D,性能仍然落后了很多。虽然还不是一个纯transformer,但考虑到在大多数GAN应用中,判别器在训练后被丢弃,而只有生成器是被留作测试使用。因此,Transformer G+AutoGAN D的有希望的结果已经具有实践关联了。transformer保留以供测试使用。如果目标只是简单的获得一个基于Transformer的G,那么通过Transformer G+AutoGAN D,可以充分的实现这个目标。然而本论文终极目标是让GAN完全没有卷积,因此研究还得继续。

数据增强(DA)

无论使用CNN还是基于transformer的G,似乎D都没有得到很好的训练。已知基于transformer的分类器存在高度地数据饥渴问题,这是由于人为设计的偏好移除导致:它们不如CNN,直到使用大得多的外部数据进行预训练。为了清除这一障碍,数据增强成为一大利器,揭示了不同类型的强数据增强可以促进对视觉transformer进行数据高效的训练。传统上,与训练图像分类器相反,训练GAN几乎不考虑数据增强。最近,在“小样本”制度下训练GAN的兴趣激增,旨在使用精心制作的数据增强技术,将最先进的GAN结果与数量级更少的真实图像进行匹配。

在一个不同的环境中,比较数据增强对CNN和基于transformer的GANs的影响。使用CIFAR-10的整个训练集,并将TransGAN与三种最先进的基于CNN的GAN进行比较:WGAN-GP,AutoGAN和StyleGAN v2。数据增强方法采用DiffAug。对于三个基于CNN的GAN,在全数据状态下,数据增强的性能增益似乎减小。只有最大的模型StyleGAN-V2似乎在IS和FID两个方面都有明显的提高。与此形成鲜明对比的是,在同一训练集上训练的TransGAN的增幅是惊人地大:分别从6.95提高到8.15和从41.41降为19.85。这再次证实了基于transformer的架构比CNN更需要数据,并且在很大程度上可以通过更强的数据增强来帮助实现。

与自我监督辅助任务的协同训练(MT-CT)

NLP领域的transformer受益于多个预训练任务。有趣的是,添加一个自监督的辅助任务(例如,旋转预测)之前也被发现可以稳定GAN训练。这使得将自监督辅助协同训练纳入TransGAN成为一种自然的想法,这或许有助于它捕获更多的图像先验。具体来说,除了GAN损失之外,还构造了一个超分辨率的辅助任务。将现有的真实图像视为高分辨率,然后对其进行下采样以获得低分辨率的对应物。生成器的损耗加上一个辅助项λ *LSR,其中LSR是均方误差(MSE)的损失,系数λ按照经验设定为50。多任务协同训练(MT-CT)将TransGAN从8.15的IS和19.85的FID分别改进到8.20的IS和19.12的FID。

局部感知的自注意力初始化(Local Init.)

CNN架构具有自然图像平滑性的内置先验,这被认为有助于自然图像生成。这是transformer架构所缺乏的,它的特征具有充分的学习灵活性。然而,transformer仍然倾向于从图像中学习卷积结构。因此,一个有意义的问题就出现了,是否能够在保持transformer灵活性的同时,有效地对图像的归纳偏好进行编码。在完全不改变纯transformer结构的情况下,通过适当地热启动自注意力,来追求这一点。

由于局部自注意力在早期训练阶段最有帮助,但在后期训练阶段可能会受到损害最终可达到的性能。为了注入这个特定的先验,引入了局部感知的自注意力初始化。引入了一个掩码,通过该掩码,每个查询只允许与其未被“掩码”的局部邻居交互。不同于以往的方法,在训练过程中,逐渐减少掩码,直到把它缩小到最终自注意力完全全局化。(在实现方面,控制允许的交互式邻域token的窗口大小。对于第0-20个阶段,窗口大小为8,然后对于第20-30个阶段,窗口大小为10,对于第30-40个阶段,窗口大小为12,对于第40-50个阶段,窗口大小为14,然后之后都是完整图像。)可以认为这种局部感知初始化是一种正则化器,它用于早期训练动态,然后逐渐消失。它将强制TransGAN学习图像生成,首先优先考虑局部邻域(提供“必要的细节”),然后更广泛地利用非局部的交互(可能提供更多“更精细的细节”和噪声)。

扩展到大型模型

所有以前的训练技术都有助于更好、更稳定的TransGAN,仅由基于transformer的G和D组成。配备了所有这些技术(DA、MT-CT和Local Init.)后,可以scale TransGAN,看看可以从更大的模型中获得多大的收益。

首先将基于transformer的G的嵌入维度增大,从(默认)384到512,然后到7682,并将得到的模型分别表示为TransGAN-S、TransGAN-M和TransGAN-L。这就带来了在IS(提升为0.28)和FID(提升为4.08)上的持续而明显的提高,而无需任何额外的超参数调优。然后增加TransGAN-L顶部的深度(transformer编码器块的数量)。原transformer G在每个阶段中有{5,2,2}个编码器块,增加编码器块的数目到{5,4,2},并且嵌入维度也增加到1024,从TransGAN-L获得TransGAN-XL。尽管如此,IS和FID都有提升,而且FID指标又减少了2.57。实验发现放大G可以显著提高性能,而放大D的影响似乎可以忽略不计。因此,默认增加G大小,保持D不变。

实验

TransGAN可以有效地提高更大的模型和高分辨率的图像数据集。与当前基于卷积骨干的SOTA GAN相比,提出的最好的架构实现了极具竞争力的性能。具体来说,TransGAN在STL-10上设置了新的SOTA IS得分10.10分和FID得分25.32分。它也在CIFAR-10上取得了IS 8.63的分数和FID 11.89的分数,在CelebA 64×64上取得了FID 12.23的分数。这些成果是颇具竞争力的。

贡献

模型架构:使用纯transformer和无卷积来构建第一个GAN。为了避免过多的内存开销,创建了一个内存友好的生成器和一个patch-level判别器,这两个都是基于transformer的,没有一些花里胡哨的trick。跨GAN可以有效地应用到更大的模型。

训练技术:研究了一系列技术来更好地训练TransGAN,包括数据增强、具有自监督辅助损失的生成器的多任务协同训练以及自注意力的局部初始化。本文做了大量的消融研究、讨论和见解。它们都不需要任何架构更改。

性能:TransGAN实现了与当前最先进的基于CNN的GAN相比较的极具竞争力的性能。

小结

一个GAN由发生器G和判别器D组成,本论文从用transformer代替G或D开始,了解设计灵敏度;然后替换这两个,并优化提出的设计,以提高内存效率。在G和D都是transformer的vanilla TransGAN的基础上,逐步引入了一系列训练技术来弥补其不足,包括数据增强、辅助任务协同训练、自注意力注入局部性等。有了这些帮助,TransGAN可以放大到更深/更宽的模型,并生成高质量的图像。

你可能感兴趣的:(CVPR,2021)