目录
- 一、原文摘要
- 二、介绍
- 三、为什么提出TransGAN?
- 四、主要框架
- 4.1、生成器
- 4.2、鉴别器
- 4.3、Self-Attention的一种变体:Grid Self-Attention
- 五、改进性策略
- 5.1、数据增强
- 5.2、相对位置编码
- 5.3、修正后的归一化
- 六、实验
- 6.1、数据集
- 6.2、实验设置
- 6.3、实验结果
- 6.4、消融实验
- 6.5、实验消耗
TransGAN是UT-Austin、加州大学、 IBM研究院的华人博士生构建了一个只使用纯 transformer 架构、完全没有卷积的 GAN,并将其命名为 TransGAN。该论文已被NeruIPS(Conference and Workshop on Neural Information Processing Systems,计算机人工智能领域A类会议)录用,文章发表于2021年12月。
该文章旨在仅使用Transformer网络设计GAN。Can we build a strong GAN completely free of convolutions?
论文地址:https://arxiv.org/abs/2102.07074
代码地址:https://github.com/VITA-Group/TransGAN
本博客是精读这篇论文的报告,包含一些个人理解、知识拓展和总结。
最近,人们对Transformer产生了爆炸性的兴趣,这表明Transformer有可能成为计算机视觉任务(如分类、检测和分割)的强大“通用”模型。虽然这些尝试主要研究区分模型,但我们探索了一些更为困难的视觉任务,例如生成性对抗网络(GAN)。
我们的目标是进行第一次试验性研究,仅使用纯Transformer架构,构建完全没有卷积的GAN,我们的vanilla-GAN架构被称为TransGAN,包括一个基于内存友好的Transformer的生成器,该生成器可逐渐提高特征分辨率,并相应地包含一个多尺度鉴别器,可同时捕获语义上下文和低级纹理。
在此基础上,我们引入了新的网格自关注模块,以进一步缓解内存瓶颈,从而将TransGAN扩展到高分辨率生成。我们还开发了一个独特的训练配方,包括一系列可以缓解TransGAN训练不稳定性问题的技术,如数据增强、修改的标准化和相对位置编码。
与目前最先进的使用卷积主干的GANs相比,我们的架构实现了极具竞争力的性能。TransGAN能够生成具有高保真度和合理纹理细节的各种视觉示例。此外,通过可视化训练动态,我们深入研究了基于Transformer的生成模型,以了解它们的行为与卷积模型的区别。
而本文主要创新点如下:
最初的transformer是为NLP设计的,在NLP中,多头自我注意层和前向反馈网络层层被堆叠起来,以捕捉单词之间的长期相关性,最近,Transformer在图像生成方面也有进展,通过替换CNN的某些组件,将Transformer模块结合到图像生成模型中,然而其CNN的整体架构仍然存在(包括用于发生器的CNN编码器/解码器,以及完全基于CNN的鉴别器)。
如果以逐个像素作为输入,32*32的低分辨率图像也会导致1024长度的序列,与单词序列相比,数据指数级增长,如果再加入注意力,则参数爆炸式增长。于是作者的策略是分阶段迭代提高分辨率,即增加输入序列同时逐渐降低维数。
鉴别器的任务是区分真假图像,也就是分类任务。作者设计了一个多尺度的鉴别器,在不同的阶段以不同大小的面片作为输入。(因为三种不同的序列能够同时提取语义结构和纹理细节。)
Self-attention虽然使生成器能够捕获全局对应关系,但在建模高的分辨率时,会出现超长序列,会极大影响效率,于是作者提出了Grid Self-Attention:
Grid Self-Attention将全尺寸特征映射划分为几个非重叠网格,网格内进行Self-attention(分成多个块,块内做标准的self-attention,然后将每个块相连)。
Grid Self-Attention在TransGAN中,只被运用在64×64以上分辨率以减少消耗,64以下的仍然采用标准的self-attention。这样的做法从战略上平衡局部细节和全局效率。
对比卷积来说,Transforme是更需要数据的,不同类型的强大数据增强可以为Transformer提供高效的训练。
作者从三个角度进行了数据增强:Translation, Cutout, Color,让TransGAN的性能有了惊人的提高。
Translation是做些许偏移,Cutout在图像上加一些纯白或者纯黑的像素点,Color就是改变图像的对比度、饱和度。
虽然经典的transformer已经有相对位置编码,但是其发挥出的作用不够明显。
作者将 Attention ( Q , K , V ) = softmax ( ( Q K T d k V ) \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\left(\frac{Q K^{T}}{\sqrt{d_{k}}} V\right)\right. Attention(Q,K,V)=softmax((dkQKTV)改为 Attention ( Q , K , V ) = softmax ( ( ( Q K T d k + E ) V ) \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\left(\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+E\right) V\right)\right. Attention(Q,K,V)=softmax(((dkQKT+E)V),其中E取自矩阵M,并作为残差项添加(M是同时考虑H轴和W轴,用表示相对位置的参数化矩阵 M ∈ R ( 2 H − 1 ) × ( 2 W − 1 ) M \in \mathbb{R}^{(2 H-1) \times(2 W-1)} M∈R(2H−1)×(2W−1))
相对位置编码学习了内容之间更强的“关系”,能够极大提升性能。
归一化层(Normalization )有助于稳定深层神经网络的深层学习训练,效果显著,原版标准归一化使用的是layer normalization,作者提出了一种 Y = X / 1 C ∑ i = 0 C − 1 ( X i ) 2 + ϵ Y=X / \sqrt{\frac{1}{C} \sum_{i=0}^{C-1}\left(X^{i}\right)^{2}+\epsilon} Y=X/C1∑i=0C−1(Xi)2+ϵ,其中 ϵ = 1 e − 8 {\epsilon}=1e-8 ϵ=1e−8,X和Y表示缩放层前后的标记,C代表嵌入维度。(类似于AlexNet中曾经使用的局部响应规范化)
CIFAR-10、STL10和CelebA数据集。
遵循WGAN的设置,并使用WGAN-GP损失, 生成器的batch大小为128,鉴别器的batch大小为64,选择DiffAug作为培训过程中的基本增强策略。评价指标使用IS和FID。