多模态系列论文--ALBEF 详细解析

ALBEF来自于Align before Fuse,作者团队全自来自于Salesforce Research。
论文地址:Align before Fuse: Vision and Language Representation Learning with Momentum Distillation
论文代码:ALBEF

1 摘要

最近图像文本的大规模的特征学习非常火爆,大部分已有的方法都是用一个Transformer模型作为多模态的一个编码器,同时编码视觉的Token和文本的Token,视觉Token就是视觉特征,一般是region-based的图像特征。
本文贡献1–ALign BEfore Fuse:
ViLT和ALBEF都认为不需要目标检测的模型,但ViLT只是说用了目标检测模型以后速度太慢,希望推理时间变得更快,但是ALBEF分析认为使用预训练的目标检测器之后,视觉特征和文本特征无法align对齐,因为目标检测器是提前训练好的,只抽取特征,没有再进行end-to-end的训练,所以导致视觉特征和文本特征可能相隔很远,此时将两个特征同时送入多模态的编码器之后,编码器学习图像文本之间的交互信息就会变得很有挑战。作者提出对比学习的ITC Loss,将图像和文本在Fusing之前实现Align,也就是他们论文的题目ALign BEfore Fuse。
本文贡献2–自训练方式学习:
网上爬取的数据特别noisy,因为图片文本对里面的文本通常具备搜索属性,但不是好的描述性句子,甚至没有描述,这种noisy web data的问题影响模型有效地学习文本图像特征,所以作者使用Momentum Distillation也就是自训练的方式去学习,自训练就是用Pseudo Label伪标签去学习。作者采用Moco提出的Momentum Encoder的形式生成Pseudo Target,从而达到一个自训练的效果。
在图文检索的任务上,ALBEF的效果最好,性能反超之前的方法,尤其是在特别大的数据上训练过的模型CLIP和Align。在VQA和VR任务上,ALBEF也比之前state of art提升2.37%、3.84%的准确度,而且推理时间也更快。模型在4 million的训练数据集下,能做到一个8卡A100训练三四天的时间,训练门槛大大降低。

2 相关工作

多模态系列论文--ALBEF 详细解析_第1张图片
上图是早期工作的一个简单总结:

  1. 图 (a) 是VSE或者VSE++的架构,它们的文本端就是直接抽一个文本特征,但是它们的视觉端非常大,需要的计算量非常多,因为它是一个目标检测器。当得到了文本特征和视觉特征之后,它最后只能做一个很简单的模态之间的交互,从而去做多模态的任务。
  2. 图(b)是CLIP的结构,视觉端和文本端都用同等复杂度的encoder进行特征提取,再做一个简单的模态交互,结构优点是对检索任务极其有效,因为它可以提前把特征都抽好,接下来直接算Similarity矩阵乘法就可以,极其适合大规模的图像文本的检索,非常具有商业价值。缺点是只计算Cosine Similarity无法做多模态之间深度融合,难一些的任务性能差。
  3. 图(c)是Oscar或者ViLBERT、Uniter采用的架构,因为对于多模态的任务,最后的模态之间的交互非常重要,只有有了模态之间更深层的交互,VQA、VR、VE这些任务效果才会非常好,所以他们就把最初的简单的点乘的模态之间的交互,变成了一个Transformer的Encoder,或者变成别的更复杂的模型结构去做模态之间的交互,所以这些方法的性能都非常的好,但是随之而来的缺点也很明显:所有的这一系列的工作都用了预训练的目标检测器,再加上这么一个更大的模态融合的部分,模型不论是训练还是部署都非常的困难
  4. 图 (d) 是ViLT的架构。当Vision Transformer出来之后,ViLT这篇论文就应运而生了,因为在Vision Transformer里,基于Patch的视觉特征与基于Bounding Box的视觉特征没有太大的区别,它也能做图片分类或者目标检测的任务,因此就可以将这么大的预训练好的目标检测器换成一层Patch Embedding就能去抽取视觉的特征,所以大大的降低了运算复杂度,尤其是在做推理的时候。但是如果文本特征只是简单Tokenize,视觉特征也只是简单的Patch Embedding是远远不够的,所以对于多模态任务,后面的模态融合非常关键,所以ViLT就直接借鉴 © 类里的模态融合的方法,用一个很大的Transformer Encoder去做模态融合,从而达到了还不错的效果。因为移除了预训练的目标检测器,换成了可学习的Patch Embedding Layer。
    ViLT的优点:
    模型极其简单。它虽然是一个多模态学习的框架,但跟NLP的框架没什么区别,就是先Tokenized,然后送到一个Transformer去学习,所以非常的简单易学。
    ViLT的缺点:
    1)它的性能不够高,ViLT在很多任务上是比不过 © 类里的这些方法的,原因是对于现有的多模态任务,需要更多的视觉能力(可能是由于数据集的bias),因此视觉模型需要比文本模型要大,最后的效果才能好,但是在ViLT里,文本端用的Tokenizer很好,但是Visual Embedding是Random Initialized,所以它的效果就很差
    2)ViLT虽然推理时间很快,但它的训练时间非常慢,在非常标准的一个4 million的数据集set上,ViLT需要64张32G的GPU训练三天,它训练的复杂度和训练的时间丝毫不亚于 © 类的方法,所以它只是结构上简化了多模态学习,但训练难度并没有降低。

通过上面的分析可以对模型和LOSS函数的选择有所启发:

  1. 模型的结构选择:
    1)因为有图像的输入和文本的输入,模型应有两个分支分别抽取图像文本特征。
    2) 在多模态学习里,视觉特征重要性远远要大于这个文本特征,所以应该使用更大更强的视觉模型。
    3)多模态学习模态之间的融合也非常关键,因此需要模态融合的模型尽可能大,所以好的多模态学习网络结构应该像 ©,也就是文本编码器比图像编码器小,多模态融合的部分尽可能大。
  2. Loss选择:
    1) Image Text Contrastive Loss(ITC):CLIP模型用一个对比学习的ITC Loss训练,效果就很好,而且训练也很高效,建议使用。
    2) Mask Language Modeling(MLM):BERT的训练方式,遮住一个词然后再去预判这个词(完形填空)
    这个Loss非常有用,目前不论NLP还是Vision,都大一统的用Mask Modeling,建议使用
    3)Image Text Matching Loss(ITM):在之前的 © 和 (d) 方法中都取得了很好的结果,建议使用。
  3. Loss不选择WPA:
    对于(c) 类的这种方法,使用目标检测,所以使用Word Patch Alignment Loss,就是文本的一个单词和图像上的一个Patch应该有一个对应关系,现在不用目标检测模型,而且在ViLT里WPA Loss计算非常慢,模型训练麻烦,不建议使用。

3 ALBEF预训练

3.1 网络结构

多模态系列论文--ALBEF 详细解析_第2张图片

  1. 图像端:给定任何一张图片,按照Vision Transformer的做法,把它打成patch,然后通过patch embedding layer送给一个Vision Transformer。这里是一个非常标准的12层的Vision Transformer的base模型,如果图片是224x224,那这里的sequence length就是196,然后加上额外的一个CLS token就是197,它的特征维度是768,所以这里绿黄色的特征
    就是197乘以768。但论文在预训练阶段用的图片是256x256,所以这里绿色的sequence length就会相应的再长一些。它的预训练参数用的DEiT,也就是Data Efficient Vision Transformer在ImageNet 1K数据集上训练出来的初始化参数。
  2. 文本端:文本端为保持计算量与clip类似,并且增强模态融合的部分,用前六层做文本编码,剩下的六层transformer encoder作为multi-model fusion的过程。文本模型用BERT模型做初始化,它中间的特征维度也是768,它也有一个CLS token代表了整个句子的文本信息。
  3. momentum model:ALBEF为了做momentum distillation,而且为了给ITC loss提供更多negative,增加了momentum model,这个模型参数由左边训练的模型参数通过moving average得到的(和MoCo一模一样),通过把moving average的参数设的非常高(论文里是0.995)来保证momentum model不会那么快更新,产生的特征更加稳定,不仅可以做更稳定的negative sample,而且还可以去做momentum distillation。

3.2 损失函数:

  1. ITC loss:
    对比学习就是首先定义一个正样本对,然后定义很多负样本对,对比使正样本对之间的距离越来越近,正负样本对之间的距离越来越远。在ALBEF里首先将图像I通过vision transformer之后得到图像的全局特征,图中黄色CLS token作为全局特征,也就是一个768×1的向量,文本这边先做tokenize,将文本text变成tokenization的序列,再输入BERT的前六层,得到了一系列的特征,文本端的CLS token作为文本的全局特征,也是一个768×1的向量。接下来与MoCo相同,图像的特征向量先做一下downsample和normalization,将768×1变成了256×1的向量,文本特征向量也是768变成256×1。正样本这两个特征距离尽可能的近,它的负样本全都存在一个q里,含有65536个负样本(因为它由momentum model产生的,没有gradient,所以它并不占很多内存),正负样本之间的对比学习,使得这两个特征距离尽可能的远。这个过程就是align before fuse的align,也就是说在图像特征和这个文本特征输入Multi-model Fusion的encoder之前,就已经通过对比学习的ITC loss让这个图像特征和文本特征尽可能的拉近,在同一个embedding space里,具体使用了cross entropy loss。

  2. ITM loss:Image Text Matching,属于一个二分类任务,就是给定一个图片,给定一个文本,图像文本通过ALBEF的模型之后输出一个特征,在这个特征之后加一个分类头,也就是一个FC层,然后去判断I和T是不是一个对,
    这个loss虽然很合理,但是实际操作的时候发现这个loss太简单,所以这个分类任务,很快它的准确度就提升得很高无法进一步优化。因此ALBEF通过某种方式选择最难的负样本(最接近于正样本的那个负样本),具体来说ALBEF的batch size是512,所以ITM loss正样本对就是512个,对于mini batch里的每一张图像,把这张图片和同一个batch里所有的文本都算一遍cos similarity,然后它在这里选择除了它自己之外相似度最高的文本当做negative,这样ITM loss就非常有难度

  3. MLM:Mask Language Modeling,它把原来完整的句子text T变成一个T’,也就是有些单词被mask掉了
    然后它把缺失的句子和图片一起通过ALBEF的模型,最后把之前的完整的句子给预测出来,它这里也借助了图像的信息去更好的恢复被mask掉的单词。

计算ITC loss和ITM loss的时候,输入都是原始的i和t,但是当计算MLM loss的时候,它的输入是原始的i和mask后的T’,所以模型每一个训练的iteration都做了两次模型的forward:一次模型的forward用了原始的i和t,另一次模型的forward用了原始的i和mask后的T’。为了计算不同的loss,这种多次前项在多模态学习中使用普遍。
在这里插入图片描述
公式5是ALBEF的目标函数,也就是itc、mlm和itm的合体

3.3 momentum distillation动量蒸馏

  1. 动机:
    由于从网上爬下来的数据噪声严重,图像文本对经常是弱相关,甚至不匹配,这种noisy的data会导致计算目标函数时有偏差。比如在算ITC的时候,某个负样本文本很有可能也描述了图像里的很多内容,它可能不是爬下来的image text pair,但文本甚至可能比Ground Truth描述的还好。若此时将其作为一个负样本,就会对ITC的学习造成很大的影响。
  2. 改进方式:
    作者认为one hot label(就是图片和文本就是一对,其他跟它都不是一对)对于ITC和MLM这两个loss来说不好,因为有的负样本也包含了很多的信息。所以作者采取了自训练方式,先构建一个momentum model然后用这个动量模型去生成pseudo targets伪目标(其实就是一个softmax score),这样它就不再是一个one hot label。
    动量模型在已有模型之上做exponential moving average EMA,目的是在原始模型训练的时候,不仅希望模型预测与ground truth的one hot label去尽可能的接近,还希望模型预测与动量模型出来的pseudo targets尽可能的匹配,这样就能达到一个比较好的折中点。因为当one hot label正确时,可以学习到很多信息,但当one hot label是错误的,或者是noisy的时候,作者希望稳定的momentum model能够提供一些改进。
    以ITC loss为例,它是基于这个one hot label的,所以这里再算一个pseudo target loss去弥补它的一些缺陷和不足。这个loss跟前面equation1里的ITC loss的不同,这里将ground truth换成这个q,就是pseudo targets,q不再是one hot label,而是softmax score,所以这里计算KL divergence而不是cross entropy。
  3. 最终结果:
    最终ALBEF的训练loss有五个:两个ITC、两个MLM、一个ITM,其中:
    1)ITC有两个loss:一个是原始的ITC,一个是基于pseudo target的ITC,所以分别加权(1-α)和α的loss weight,最终得到momentum版本的ITC loss。
    2)MLM loss有两个loss:一个是原始的MLM,一个是基于pseudo target的MLM,用新生成的pseudo target去代替了原来的ground truth,分别加权(1-α)和α的loss weight,最终得到momentum版本的MLM loss。
    3)ITM有一个loss:ITM没有动量版本,因为本身它就是基于ground truth,它就是一个二分类任务,而且在ITM里又做了hard negative,这跟momentum model有冲突。

3.4 预训练用的数据集

本文与UNITER论文相同,用了四个数据集来预训练,分别是Conceptual Captions、SBU Captions、COCO和Visual Genome,具体如下表,其中,CC数据集原始版本是300万和1200万,但下载时候会有数据缺失的问题,所以数据集差几万个可能影响小,但差几十万个可能影响算法的效果,并且每个人下载下来的数据集可能都不完全一致。
表格
ALBEF的预训练数据集有两个设置:一个是4 million的setting有四个数据集,一个就是这个14 million的setting,加入CC12 million。这两个setting也是之前最常用的,不过现在预训练数据集规模已经越来越大,开源的LAION已经有了400 million,而且已经有了2 billion 5 billion这么大的数据集。

3.5 训练细节

模型由123.7M的BERTbase 和一个 85.8M的 ViT-B/16组成,训练的batchsize是512,使用8 个NVIDIA A100 的GPU训练30个epochs,使用AdamW优化器, 0.02的weight decay。在前1000轮迭代中,学习率预热到1e−4,然后按照余弦函数衰减到1e−5,在预训练阶段,随机裁剪256 × 256图像作为输入,使用了RandAugment数据增强,在fine-tuning阶段,图像分辨率增加到384 × 384,并且按照文献38的方法对图像patches的位置编码进行插值,更新momentum model的momentum参数设置为0.995,图像文本对比学习的序列大小设置为 65,536,在第一个epoch 将蒸馏权重 α 从0升到 0.4。

4 互信息最大化视角

作者通过互信息最大化的角度,做了一些理论分析,结论就是文章里的这些训练的目标函数,包括LIM、ITM、还有momentum distillation,它们最终的作用都是为同一个图像文本对生成不同的视角,也就是在做一种data augmentation,从而让最后训练出来的模型具备Semantic Preserving的功能,就是只要是语义匹配的图像文本对,它就应该被当成是一对。
这样ALBEF不需要Bounding Box Annotation,也就不需要预训练的目标检测,自然输入图像也就不需要是目标检测需要的高分辨率的图像,这个CLIP和ViLT已经有了。

5 下游Vision Language任务

任务说明

ALBEF这篇论文做了五个任务:

  1. 图文检索(图像到文本、文本到图像、图像到图像、文本到文本):
    定义:给定一个数据库,去搜索Ground Truth的图像文本对。
    衡量指标:因为是检索,衡量的指标就是这个Recall 召回率,一般用的是R1,R5,R10,就是说在你检索回来的一个、五个或者十个Sample里是否有Ground Truth Sample,有就算正确。
  2. 视觉蕴含(Visual Entailment):
    定义:给定一个假设或者一个前提,能否推理出这个前提,如果能推理出来,就是一个蕴含的关系Entailment。如果前后矛盾推不出来,就是矛盾Contradictory。如果没关系,不确定能否推出来,就是中立Neutral。
    衡量指标:一般将Visual Entailment变成一个三分类的问题,衡量的指标就是分类准确度。
  3. 视觉问答(VQA):
    定义:给定一个问题一个图片,能否提供一个Answer回答这个问题,一般有两个Setting:一个是看作分类问题,它的这些答案都是固定的,从一个集合中选择一个正确答案,一般称作闭集VQA,例如VQA2.0数据集有一个提前设好的3192个Answer。一个是看作生成问题,需要一个Transformer Decoder做文本生成的任务,一般称作开集VQA,开集VQA它的任务难度大很多,因为有可能生成了正确的或者很相似的答案,但是它跟Ground Truth不一致,还是会被判错。
    衡量指标:分类准确度。
  4. 视觉推理(Visual Reasoning):
    定义:预测一个文本能不能描述一张图片,所以它是一个二分类任务问题
    衡量指标:分类准确度。
  5. Visual Grounding,但其实Visual Grounding属于它自己的一个领域,很多多模态表征学习的论文里, 都不会去涉及这个的任务
    ALBEF虽然做了Answer Generation的问题,但在做推理的时候,还是把生成的答案限制到3192个答案里,所以还是一个分类问题。

消融实验

作者将用了MLM和ITM training loss的模型当做了baseline,因为基本上所有的之前的工作和现在的工作里面都会有这两个loss。

  1. 增加ITC也就是Align Before Fuse里的Align,这个提升是非常巨大的,基本是两个多点,三个点,而且是在检索、VE、VR、VQA这么多任务上都有明显的提升,所以CLIP、MoCo这种对比学习的方式非常厉害
  2. 增加ITM里提出的Hard Negative,大概都有0.5左右的提升
  3. 增加Momentum Distillation(MoD),在预训练阶段的MoD提升可能只有0.3 0.4个点,但是这个研究方向还是很好的,就是怎么从Noise Data里去学习有效的表征。
    多模态系列论文--ALBEF 详细解析_第3张图片

各任务测试结果

表2是fine-tuned的图文检索结果,表3 是zero-shot的图文检索结果。
多模态系列论文--ALBEF 详细解析_第4张图片
Zero-shot的图文检索就是先预训练好一个模型,接下来就直接抽取特征,然后计算Cosine Similarity就可以直接检索,不需要在下游任务上fine-tune。但需要注意,本文在预训练完之后还在COCO那个数据集上又做了一遍有监督的fine-tune,但是CLIP和Align其实都是直接做Zero-shot。而且只是在Flickr30K上做了Zero-shot。
图文检索的数据集COCO、Flickr30K都已经刷的非常高了,所以这个领域需要新的更大的数据集了。
多模态系列论文--ALBEF 详细解析_第5张图片
VQA、VR、VE三个任务都变成了一个分类问题,所以都用准确度来衡量。表4中跟之前最强的OSCAR比,ALBEF 4M其实就已经比它都要好了。

总之 ALBEF不论是在训练速度还是在推理速度上,还是在通用性或者在性能表现上,当时都是非常亮眼的工作。

你可能感兴趣的:(多模态论文,多模态)