点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”
作者:Abe Fetterman , Josh Albrecht
编译:ronghuaiyang
导读
对自监督学习中学习的本质进行剖析,扩展了对比学习中负样本的概念,并不是一定要使用不同的图像才可以产生负样本,batch normalization同样也可以产生负样本。
概要
与之前工作SimCLR和MoCo不同,最近的一篇来自DeepMind的论文"Bootstrap Your Own Latent" 展示了一个先进的自监督学习的方法,不需要明确的对比损失函数。通过消除损失函数中对负样本的需要,简化了训练。我们复现BYOL的时候,强调了两个令人惊讶的发现:
(1)在删除batch normalization时,BYOL的性能通常不比random好
(2)batch normalization的存在隐式地导致了一种对比学习的形式
这些发现强调了在学习表征时正样本和负样本对比的重要性,并帮助我们更基本地理解自监督学习是如何以及为什么有效的。
论文代码:https://github.com/untitled-ai/self_supervised
机器学习通常是在“监督”的方式下完成的:我们使用一个由输入和“正确答案”(输出)组成的数据集,来找到从输入数据映射到正确答案的最佳函数。相比之下,在自监督学习中,在数据集中没有正确的答案,因此我们学习一个函数,把输入数据映射输入到本身(比如:使用图像的右半部分预测图像的左半部分)。
从语言到图像和音频,这种方法已经被证明是成功的。事实上,最近的语言模型,从word2vec到BERT和GPT-3,都是自监督方法的例子。最近,这种方法在音频和图像方面也取得了一些令人难以置信的成果,而且[一些人相信](https://cacm.acm.org/news/244720-yann-lecunyoshua-bengio -self- learning- isketo -human-级智能/全文),它可能是类人智能的一个重要组成部分。这篇文章关注的是图像表示的自监督学习。
在BYOL发布之前,性能最好的算法是MoCo和SimCLR。MoCo和SimCLR都是对比学习的例子。
对比学习是训练分类器区分“相似”和“不同”输入数据的过程。具体到MoCo和SimCLR,分类器的正样本是同一幅图像的不同的修改版本,负样本是同一数据集中的其他图像。假设有一张狗的图片。在这种情况下,正样本可以是对该图像的不同的crop,而负样本可以是来自完全不同的图像的crop。
狗的原始图片的增强版本(a)。其中任何两个都可以用作是正样本对。MoCo和SimCLR在其损失函数中使用正样本和负样本的对比学习,而BYOL在损失函数中只使用正样本。乍一看,BYOL似乎在进行自监督学习的时候完全没有对比不同的图像。然而,BYOL能起作用的主要原因似乎是它确实是在做一种对比学习 —— 只不过是通过一种间接的机制。
为了更深入地理解BYOL中的这种间接的对比学习,我们应该首先回顾一下这些算法是如何工作的。
SimCLR是一个特别优雅的自监督算法,它简化了以前的方法,使其成为核心,并提高了性能。同一幅图像x的两个变换v和v ' 通过同一个网络产生两个投影z和z '。对比损失的目的是使同一个输入x的两个投影的相似度最大化,同时使同一个minibatch内其他图像的投影相似度最小化。继续使用我们的狗的例子,对同一张狗的图像进行不同的crop投影要比从其他随机图像中的crop在同个batch中更相似。
在SimCLR中用于投影的多层感知器(MLP)在每个线性层之后使用batch normalization。
SimCLR结构相对于SimCLR, MoCo v2能够减少batch size(从4096减少到256)并提高性能。不像SimCLR,两个网络共享相同的参数、MoCo将单一网络分为一个online network参数为θ,和momentum network参数ξ。online network采用随机梯度下降法进行更新,momentum network采用online network的指数移动平均法进行更新。online network允许MoCo将之前的投影放入memory bank中以进行高效的利用,并作为对比损失的负样本。这个memory bank支持更小的batch size。在我们的狗图像的示意图中,正样本是同一副狗的图像的不同的crops,负样本是在过去的mini-batch中使用过的完全不同的图像,这些图像存储在memory bank中。
在MoCo v2中用于投影的MLP不使用batch normalization。
MoCo v2 结构,上面是online编码器,下面是momentum编码器BYOL建立在MoCo的momentum network概念上,增加了一个MLP(qθ),用来从z中预测z ',而不是使用对比损失,BYOL使用了L2来计算归一化预测p和目标z '之的误差。以我们的狗图像为例,BYOL尝试将狗图像的两种crop转换为相同的表示向量(使p和z '相等)。因为这个损失函数不需要负样本,所以BYOL中不需要使用memory bank。
BYOL中的两个MLPs只在第一个线性层之后使用batch normalization。
BYOL结构根据上面的描述,BYOL似乎可以在不明确地对比多个不同图像的情况下学习。然而,令人惊讶的是,我们发现BYOL不仅在做对比学习,而且对比学习对它的成功是必不可少的。
我们最初使用为MoCo编写的代码在PyTorch中实现了BYOL。当我们开始训练我们的网络时,我们发现我们的网络的表现并不比random好。将我们的代码与[另一个实现:https://github.com/sthalles/PyTorch-BYOL进行了比较,我们发现MLP中缺少了batch normalization。我们很惊讶batch normalization化对于BYOL的训练是至关重要的,而MoCo v2根本不需要它。
对于我们的初始测试,我们使用带有动量的SGD,batchsize为256的STL-10无监督数据集训练了一个使用BYOL的ResNet-18。下面是在MLPs中使用和不使用batch normalization的同一个BYOL算法的前10个epochs的训练。
在STL10上ResNet-18的早期训练中验证集的精度基本是线性的。在MLP中不进行批处理归一化的BYOL训练时,其性能并不比随机基线好。为了调查性能发生这种戏剧性变化的原因,我们执行了一些额外的实验。
使用对比损失的实验配置,更好的和BYOL结果进行比较因为与MoCo相比预测的MLP q改变了网络深度,我们想知道是否需要batch normalization来规范这个网络。也就是说,虽然MoCo不需要batch normalization,但是当与额外的预测MLP q配对时,MoCo可能需要batch normalization。为了测试这一点,我们开始用一个对比损失函数来训练上面的网络。我们发现,在10个epoch内,该网络的性能明显优于随机网络。这个结果让我们怀疑没有使用对比损失函数会导致训练依赖于batch normalization。
然后,我们想知道另一种类型的规范化是否会有同样的效果。我们对MLPs应用了Layer Normalization而不是batch normalization,并使用BYOL对网络进行了训练。在MLPs未进行归一化的实验中,其性能并不比随机的好。这个结果告诉我们,在同一个小batch中激活其他输入对于帮助BYOL找到有用的表示是至关重要的。
接下来,我们想知道是否在投影MLP g、预测MLP q或两者中都需要batch normalization。我们的实验表明批batch normalization在投影MLP中是最有用的,但是网络可以通过batch normalization在任一MLP中学习到有用的表示。在MLPs中只需要一个batch normalization就足以让网络学习了。
总结一下到目前为止的发现:在缺乏对比损失函数的情况下,BYOL训练的成功取决于与minibatch中其他输入的激活相关的batch normalization层。
在对比损失函数中使用负样本的一个目的是防止模式坍塌。模式坍塌的一个例子是一个网络总是输出[1,0,0,0,…]作为它的投影向量z。如果所有的投影向量z都相同,那么网络只需学习q的恒等函数就可以达到完美的预测精度!
batch normalization的重要性在此上下文中变得更加清晰。如果在投影层g中使用batch normalization,那么投影输出向量z就无法坍缩成任何奇异值,如[1,0,0,0,…],因为这正是batch normalization所避免的。无论如何相似的输入经过batch normalization,输出将根据学习的平均值和标准偏差重新分配。这样就精准地防止了模式坍塌,因为在batch normalization,所有的minibatch的样本无法取相同的值。
batch normalization在预测MLP中可以产生类似的效果。如果minibatch的输入非常相似,q函数就无法学习identity函数:batch normalization将通过向量空间重新分配激活,因此最后一层的预测都非常不同。如果这些向量z '在表示空间中被充分地分离(也就是说,没有坍塌),这个函数只会在预测投影向量z '时成功,因为预测的p被在minibatch中可以被很好地分离。
我们的发现似乎与一个简单的结论一致:防止模式坍塌的一种方法是识别样本之间的共同模式。batch normalization在minibatch之间标识这种共同的模式,并通过使用minibatch中的其他表示形式(如隐式负样本)来删除它。因此,我们可以把batch normalization看作是在嵌入式表示上实现对比学习的一种新方法。
换句话说,通过batch normalization,BYOL通过提问来学习,“这张图像与平均图像有什么不同?”SimCLR和MoCo使用的显式对比方法是这样学习的:“这两个特定图像之间的区别是什么?”这两种方法似乎是相同的,因为将一幅图像与许多其他图像进行比较,其效果与将它与其他图像的平均值进行比较是一样的。例如,[prototypical contrastive learning](https://blog.einstein.ai/prototypical-contrastive-learing-pushing-fronties-of-unsupervised -learning/)就利用了这种等价性。
假设上述情况属实(删除batch normalization会导致BYOL模式坍塌)。在这种情况下,我们应该期望看到所有的表示和投影(z, z '和p向量)是相等的 —— 这就是我们所看到的。
在训练了上述每个变量后,我们测量了第一个输入投影向量z与第二个输入投影向量z’的余弦相似度。在训练的第10个epochs,我们测量了正样本投影(蓝色部分)与负样本投影(红色部分)之间的平均余弦相似度。
在g或q中没有batch normalization,投影与正样本和负样本高度对齐(0.9999),这表明将图像表示压缩为了同一个公共向量。因为batch normalization没有引入对比学习,它也导致了正样本的和负样本的对齐表示。对于标准的BYOL训练(即使用batch normalization),我们得到了预期的不同向量。正样本之间的投影(0.88)比负样本子之间的投影(0.27)更相似。
投影z和z'之间的平均余弦相似度。下方(蓝色)条是同一幅图像x的投影之间的相似性,上方(红色)条是同一小批中不同图像的投影之间的相似性。非MLP归一化实验和层归一化实验的所有表示具有很高的相似性,表明模态发生了坍塌。这些结果支持我们对batch normalization的理解,即隐含地引入了使用minibatch统计的对比学习。
到目前为止,我们只看了训练的前10个epochs。当我们训练的时间更长时,我们发现ResNet编码器中的batch normalization层与MLPs中的具有相似的效果。在编码器中(而不是MLPs中)进行batch normalization后,网络首先学习了坍塌了的表示的函数,然后逐渐开始从正样本中分离出负样本。
当我们从ResNet编码器中删除批处理归一化并使用SGD训练网络时,它无法学习任何东西(正是我们上面描述的原因)。
然而,当我们联系到作者时,他们友好地指出,我们并没有使用与BYOL原始论文完全相同的设置。通过从SGD切换到分层学习率自适应(LARS)或增加权值衰减,我们的网络能够再次学习(尽管性能显著下降)。
我们研究了每一种技术,发现它们只是防止模式坍塌的替代方法。此外,它们自身的健壮性明显较差 —— 它们依赖于仔细的超参数调优,如果不调优,它们很容易出现模式坍塌并且性能糟糕。因此,我们得出结论,batch normalization似乎是防止BYOL中模式坍塌的最健壮的技术。
有趣的是,即使在损失函数中没有负样本,BYOL中的batch normalization可以隐式地引入了对比学习。这一发现在事后看来是有意义的 —— 没有学习什么时候有模式坍塌,而batch normalization使模式崩溃变得不可能!无论是不同的图像彼此对比或对比每一图像与所有图像的平均值,学习的主要部分是理解事物之间的差异。
除了阐明batch normalization如何在对比学习中工作之外,这还可以作为一个教训,说明batch normalization可能会产生意想不到的副作用。通过batch normalization,网络不再是纯粹的学习与输入相应的输出的函数了。由于这个和其他原因,在训练中可能需要避免batch normalization。我们建议其他从业者也许应该默认使用其他替代方案,比如layer normalization或weight standardization with group normalization。
对未来的工作来说,相反的方向也是一个有趣的途径。与其因为这种隐式对比效果而避免batch normalization,不如直接利用它,允许在层中(而不是最后一层)进行隐式对比学习,这可能会很有趣。一个有趣的开放问题是,在训练神经网络中,batch normalization的成功有多少是由内部表示的分离直接造成的。
最后,我们发现BYOL(使用正确的超参数)甚至在没有显式的对比损失或通过batch normalization的隐式对比机制的情况下也能学到一些东西,这一点很有趣。虽然我们不建议任何从业者在实践中使用这些网络,我们认为这是一个新颖有趣的贡献,这个行为可能会提供一个有价值去了解为什么这些技术(weight decay, weight standardization以及LARS)非常有效。
—END—
英文原文:https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html
请长按或扫描二维码关注本公众号
喜欢的话,请给我个在看吧!