FlatNCE:小批次对比学习效果差的原因竟是浮点误差?

FlatNCE:小批次对比学习效果差的原因竟是浮点误差?_第1张图片

©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络

自 SimCLR [1] 在视觉无监督学习大放异彩以来,对比学习逐渐在 CV 乃至 NLP 中流行了起来,相关研究和工作越来越多。标准的对比学习的一个广为人知的缺点是需要比较大的 batch_size(SimCLR 在 batch_size=4096 时效果最佳),小 batch_size 的时候效果会明显降低,为此,后续工作的改进方向之一就是降低对大 batch_size 的依赖。那么,一个很自然的问题是:标准的对比学习在小 batch_size 时效果差的原因究竟是什么呢? 

近日,一篇名为 Simpler, Faster, Stronger: Breaking The log-K Curse On Contrastive Learners With FlatNCE 对此问题作出了回答:因为浮点误差。看起来真的很让人难以置信,但论文的分析确实颇有道理,并且所提出的改进 FlatNCE 确实也工作得更好,让人不得不信服。

FlatNCE:小批次对比学习效果差的原因竟是浮点误差?_第2张图片

论文标题:

Simpler, Faster, Stronger: Breaking The log-K Curse On Contrastive Learners With FlatNCE

论文作者:

Junya Chen, Zhe Gan, Xuan Li, Qing Guo, Liqun Chen, Shuyang Gao, Tagyoung Chung, Yi Xu, Belinda Zeng, Wenlian Lu, Fan Li, Lawrence Carin, Chenyang Tao

论文链接:

https://arxiv.org/abs/2107.01152

细微之处

接下来,笔者将按照自己的理解和记号来介绍原论文的主要内容。对比学习(Contrastive Learning)就不帮大家详细复习了,大体上来说,对于某个样本 x,我们需要构建 K 个配对样本 ,其中 是正样本而其余都是负样本,然后分别给每个样本对 打分,分别记为 ,对比学习希望拉大正负样本对的得分差,通常直接用交叉熵作为损失:

简单起见,后面都记 。在实践时,正样本通常是数据扩增而来的高相似样本,而负样本则是把 batch 内所有其他样本都算上,因此大致上可以认为负样本是随机选择的 K-1 个样本。这就说明,正负样本对的差距还是很明显的,因此模型很容易做到 ,也即 。于是,当 batch_size 比较小的时候(等价于 K 比较小), 也会相当接近于 0,这意味着上述损失函数也会相当接近于 0。

损失函数接近于 0,通常也意味着梯度接近于 0 了,然而,这不意味着模型的更新量就很小了。因为当前对比学习用的都是自适应优化器如 Adam,它们的更新量大致形式为 梯度 梯度 梯度 学习率 ,这就意味着,不管梯度多小,只要它稳定,那么更新量就会保持着 学习率 的数量级。

对比学习正是这样的场景,要想 ,那么就要 ,但对比学习的打分通常是余弦值除以温度参数,所以它是有界的, 是无法实现的,因此经过一定的训练步数后,损失函数将会长期保持接近于 0 但又大于 0 的状态。

然而, 的计算本身就存在浮点误差,当 很接近于 0 时,浮点误差可能比精确值还要大,然后 的计算也会存在浮点误差,再然后梯度的计算也会存在浮点误差,这一系列误差累积下来,很可能导致最后算出来的梯度都接近于随机噪声了,而不能提供有效的更新指引。这就是原论文认为的对比学习在小 batch_size 时效果明显变差的原因。

变微为著

理解了这个原因后,其实也就不难针对性地提出解决方案了。对损失函数做一阶展开我们有:

也就是说,一定训练步数之后,模型相当于以 为损失函数了。当然,由于 ,即 是 的上界,所以就算一开始就以 为损失函数,结果也没什么差别,现在主要还是解决的问题是 接近于 0 而导致了浮点误差问题。刚才说了,自适应优化器的更新量大致上都是 梯度 梯度 梯度 学习率 的形式,这意味着如果我们直接将损失函数乘以一个常数,那么理论上更新量是不会改变的,所以既然 过小,那么我们就将它乘以一个常数放大就好了。

乘以什么好呢?比较直接的想法是损失函数不能过小,也不能过大,控制在 级别最好,所以我们干脆乘以 的倒数,也就是以:

FlatNCE:小批次对比学习效果差的原因竟是浮点误差?_第3张图片

为损失函数。这里 是 stop_gradient 的意思(原论文称为 detach),也就是把分母纯粹当成一个常数,求梯度的时候只需要对分子求。这就是原论文提出的替代方案,称为 FlatNCE

不过,上述带 算子形式的损失函数毕竟不是我们习惯的形式,我们可以转换一下。观察到:

也就是说, 作为损失函数提供的梯度跟 作为损失函数的梯度是一模一样的,因此我们可以把损失函数换为不带 算子的 :

相比于交叉熵,上述损失就是在 运算中去掉了正样本对的得分 。注意到 通常可以有效地计算,浮点误差不会占主导,因此我们用上述损失函数取代交叉熵,理论上跟交叉熵是等效的,而实践上在小 batch_size 时效果比交叉熵要好。此外,需要指出的是,上式结果不一定是非负的,因此换用上述损失函数后在训练过程中出现负的损失值也不需要意外,这是正常现象。

实验评估

分析似乎有那么点道理,那么事实是否有效呢?这自然是要靠实验来说话了。不出意料,FlatNCE 确实工作得非常好。

原论文的实验都是 CV 的,主要是把 SimCLR 的损失换为 FlatNCE 进行实验,对应的结果称为 FlatCLR。其中,我们最关心的大概是 FlatNCE 是否真的解决了对大 batch_size 的依赖问题,下面的图像则作出了肯定回答:

FlatNCE:小批次对比学习效果差的原因竟是浮点误差?_第4张图片

▲ 不同 batch_size 下 SimCLR 与 FlatCLR 对比图

下面则是 SimCLR 和 FlatCLR 在各个任务上的结果对比,显示出 FlatCLR 更好的性能:

FlatNCE:小批次对比学习效果差的原因竟是浮点误差?_第5张图片

▲ SimCLR 和 FlatCLR 在各个任务上的对比

吹毛求疵

总的来说,原论文的结果非常有创造性,“浮点误差”这一视角非常“刁钻”但也相当精准,让人不得不点赞。

直观来看,原来交叉熵的目标是“正样本得分与负样本得分的差尽量大”,这对于常规的分类问题是没问题的,但对于对比学习来说还不够,因为对比学习目的是学习特征,除了正样本要比负样本得分高这种“粗”特征外,负样本之间也要继续对比以学习更精细的特征;FlatNCE 的目标则是“正样本的得分要尽量大,负样本的得分要尽量小”,也即从相对值的学习变成了绝对值的学习,从而使得正负样本拉开一定距离后,依然能够继续优化,而不至于过早停止(对于非自适应优化器),或者让浮点误差带来的噪声占了主导(对于自适应优化器)。

然而,原论文的某些内容设置也不得不让人吐槽。比如,论文花了较大的篇幅讨论互信息的估计,但这跟论文主体并无实质关联,加大了读者的理解难度。当然,paper 跟科普不一样,为了使文章更充实而增加额外的理论推导也无可厚非,只是如果能更突出浮点误差部分的分析更好。然后,论文最让我不能理解的地方是直接以式(3)为最终结果,这种带“stop_gradient”的表述方式虽然算不上难,但也不友好,通常来说这种方式是难以寻求原函数的时候才“不得不”使用的,但 FlatNCE 显然不是这样。

总结全文

本文介绍了对比学习的一个新工作,该工作分析了小批次对比学习时交叉熵的浮点误差问题,指出这可能是小批次对比学习效果差的主要原因,并且针对性地提出了改进的损失函数 FlatNCE,实验表明基于 FlatNCE 的对比学习确实能缓解对大 batch_size 的依赖,并且能获得更好的效果。

参考文献

[1] https://arxiv.org/abs/2002.05709

特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。

更多阅读

FlatNCE:小批次对比学习效果差的原因竟是浮点误差?_第6张图片

FlatNCE:小批次对比学习效果差的原因竟是浮点误差?_第7张图片

FlatNCE:小批次对比学习效果差的原因竟是浮点误差?_第8张图片

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

???? 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

???? 投稿通道:

• 投稿邮箱:[email protected] 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

FlatNCE:小批次对比学习效果差的原因竟是浮点误差?_第9张图片

△长按添加PaperWeekly小编

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

你可能感兴趣的:(人工智能,深度学习,过拟合,办公软件,xhtml)