SimCLR:用于视觉表征的对比学习框架

论文链接: https://arxiv.org/pdf/2002.05709.pdf

官方github链接: https://github.com/google-research/simclr

他人复现pytorch链接: https://github.com/sthalles/SimCLR

1 概况

核心观点:

  1. 数据增强(data augmentations)的组合对预测任务的表现有重要影响,对于非监督学习而言,数据增强的提升作用更大;
  2. 本文定义了一个对比损失和表征之间的可学习非线性转换,大幅提高了表征的质量;
  3. 具有对比交叉熵损失(contrastive cross entropy loss)的表征学习得益于归一化嵌入和适当地调整温度参数;
  4. 与监督学习相比,对比学习可以通过更多的训练和更大的Batch Size 获得更好的表现,更深更宽的网络对对比学习表现的提升也有益。

框架效果:

  1. 在对Image Net 进行分类实验时,在top-1精度上取得了76.5%的结果,获比之前最先进的无监督或半监督提升7%,与有监督的Res-Net-50性能相媲美;

  2. 对Image Net 1%的Label 进行微调时,SimCLR 实现了85.8%的top-5精度,相对性能提升10%,超越了100× fewer label的AlexNet;

  3. 在其他自然图像分类数据集上进行微调时,SimCLR 在12个数据集中的10个数据集上的表现相当于或优于强监督的Baseline。

    SimCLR:用于视觉表征的对比学习框架_第1张图片

2 方法

2.1 对比学习框架

SimCLR 通过潜在空间上的对比损失,最大化相同数据示例的不同增强视图之间的协议进行表征学习,主要由四个主要组件组成:

  1. 随机数据增强模块

    将任意给定的数据示例随即转换为同一示例的两个相关视图,用 x ~ i \widetilde{x}_i x i x ~ j \widetilde{x}_j x j表示,将其视为一个正对。本文使用了三种方法进行数据增强:随即裁剪和调整(如随机翻转)(裁剪后调整图像尺寸为原图大小),随机色彩失真随机高斯模糊,作者认为随即裁剪和色彩失真的结合是使网络具有良好性能的关键。

  2. 神经网络基编码器(base encoder)

    基编码器定义为 f ( ⋅ ) f(·) f(),其作用在于从增强后的数据集中提取表征向量。作者认为SimCLR可以在无任何约束的情况下选择各种网络架构的基编码器,论文中选择的是常见的ResNet。因此,对于数据 x ~ i \widetilde{x}_i x i,有:
    h i = f ( x ~ i ) = R e s N e t ( x ~ i ) \boldsymbol{h}_i=f(\widetilde{x}_i)=ResNet(\widetilde{x}_i) hi=f(x i)=ResNet(x i)
    其中, h i ∈ R d \boldsymbol{h}_i\in\R^d hiRd,为经过平均池化层之后的输出。

  3. 小型神经网络投影头(projection head)

    投影头 g ( ⋅ ) g(·) g()的作用是将编码后的表征 h i h_i hi映射到应用对比损失的潜在空间中,本文使用的是一个两层MLP,具体计算方法如下:
    z i = g ( h i ) = W ( 2 ) ( σ ( W ( 1 ) ( h i ) ) ) z_i=g(h_i)=W^{(2)}(\sigma(W^{(1)}(h_i))) zi=g(hi)=W(2)(σ(W(1)(hi)))
    其中, σ \sigma σ是ReLU 函数,用 z i z_i zi h i \boldsymbol{h}_i hi更好构造对比损失。

  4. 对比损失函数

    在对比任务中,倘若给定一个包含正对 x ~ i \widetilde{x}_i x i x ~ j \widetilde{x}_j x j的数据集 { x ~ k } \left\{\widetilde{x}_k\right\} {x k},重构损失的作用是从 { x ~ k } k ≠ i \left\{\widetilde{x}_k\right\}_{k{\neq}i} {x k}k=i中找出给定 x ~ i \widetilde{x}_i x i对应的正对 x ~ j \widetilde{x}_j x j

因此,整个SimCLR的结构如下:

SimCLR:用于视觉表征的对比学习框架_第2张图片
  1. 样本构造:

    对于预测任务,随机采样一小批( N N N个))数据样本,之后采用数据增强方法,将每一个样本进行扩充,一共得到 2 N 2N 2N个数据点,即 N N N个正对。在本文中,不直接构造负对,而是对于给定的正对 { x ~ i , x ~ j } \left\{\widetilde{x}_i, \widetilde{x}_j\right\} {x i,x j},将除该样本对以外的其他 2 ( N − 1 ) 2(N-1) 2(N1)个样本示例都视为负样本。

  2. 损失函数构造:

    本文的对比损失函数基于 l 2 l_2 l2正则化,是从其他论文中参考得到的,称我为NT-Xent(标准化温度尺度的交叉熵损失)。该loss首先定义了一个sim函数表示正则化: s i m ( u , v ) = u T v / ∥ u ∥ ∥ v ∥ \rm{sim}(\boldsymbol{u},\boldsymbol{v})={\boldsymbol{u}^\mathsf{T}}\boldsymbol{v}/\|\boldsymbol{u}\|\|\boldsymbol{v}\| sim(u,v)=uTv/uv,正对 { x ~ i , x ~ j } \left\{\widetilde{x}_i, \widetilde{x}_j\right\} {x i,x j}之间的损失函数即为:
    ℓ i , j = − log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N l [ k ≠ i ] exp ⁡ ( s i m ( z i , z j ) / τ ) \ell_{i,j}=-\log\frac{\exp(\rm{sim}(\it{z_i},z_j\rm)/\tau)}{\sum_{k=1}^{2N}{\mathbb{l}_{[k{\neq}i] }\exp(\rm{sim}(\it{z_i},z_j\rm)/\tau)} } i,j=logk=12Nl[k=i]exp(sim(zi,zj)/τ)exp(sim(zi,zj)/τ)
    上式中, l [ k ≠ i ] \mathbb{l}_{[k{\neq}i]} l[k=i]是一个指示函数,当且仅当 k = i k=i k=i时取0,否则为1。 τ \tau τ是进行优化的温度参数,后续实验显示为0.1时效果最好。

  3. 算法概述

    SimCLR:用于视觉表征的对比学习框架_第3张图片

    根据算法可以发现,在构造数据时,对 x ~ i \widetilde{x}_i x i x ~ j \widetilde{x}_j x j分别采用一种增强方法,并且分别计算了一个正对中每个数据点对另一种增强方法所有数据的余弦相似性。

2.2 大Batch Size训练方法

作者提到,为了保持框架的简单化,在模型训练时不采用记忆库训练模型,而是采用大Batch Size的形式,将Batch Size从256扩大为8192,在进行数据增强之后,一个正对将有 ( 8192 − 1 ) × 2 = 16382 (8192-1)×2=16382 81921×2=16382个负示例。为了克服大Batch Size使用SGD/Momentum进行优化时可能导致的训练不稳定问题,论文对每个Batch Size都采用了LARS优化器。

此外,由于标准ResNet架构采用Batch Normlization(BN)进行规范化,作者指出,在具有数据并行性的分布式训练中,BN的均值和方差通常在每张卡上进行局部的聚合。而论文提出的对比学习中,正对在同一张卡中计算,模型可以利用局部信息泄漏提高预测精度,而不需要改善表征,他们在训练过程中使用了一种Global BN的方法,取所有卡上BN的均值和方差作为表示。

2.3 评价方案

  1. 数据集选择

    大部分实验都在ImageNet ILSVRC-2012 数据集上进行;少部分实验在CIFAR-10上进行;还使用了一些广泛用于迁移学习的数据。

    使用线性评估协议进行评估,在一个冻结基网络的顶层训练一个线性分类器,测试分类精度,也比较了一些迁移学习和半监督方法。

  2. 默认设置

    除非另有规定:

    在数据增强时使用随机裁剪和调整(随机翻转)颜色失真,和高斯模糊三种方法。

    使用ResNet-50作为基础网络编码器,并使用两层MLP将编码后的表征投影到一个128维的潜在空间之中。

    训练时,使用NT-Xent loss,使用LARS进行优化,设置学习率为 4.8 ( = 0.3 × B a t c h S i z e / 256 ) 4.8 (= 0.3×Batch Size/256) 4.8(=0.3×BatchSize/256),权值衰减率为 1 0 − 6 10^{-6} 106。Batch Size大小为4096,共训练100个epoch。此外,在前10个epoch使用线性预热,并在不重新启动的情况下使用余弦衰减计划衰减学习速率。

3 数据增强

3.1 数据增强涵盖了对比预测任务的两种情况

作者指出,目前数据增强方法尚未被视为对比预测任务的标准手段之一,目前流行的手段都是通过改变网络架构实现对比预测,作者认为他们的数据增强方法可以在数据层面即涵盖两种典型的对比预测任务,分别是整体——局部预测邻近预测

具体来说,如下图所示,在对图像应用随机裁剪时,将会出现如下图左右两张图的正对样本,这些样本包含了上述的两种对比预测任务情形(A, B: 整体——局部;C, D: 临近)。

SimCLR:用于视觉表征的对比学习框架_第4张图片

3.2 不同数据增强方式的组合对表征学习效果至关重要

作者通过下图展示了多种数据增强的方法,需要注意的是,在实验中,其只使用了随即裁剪(包括裁剪、大小调整和翻转)、颜色失真和高斯模糊三种手段。

SimCLR:用于视觉表征的对比学习框架_第5张图片

在实验过程中,作者通过研究单独和成对数据增强方法对SimCLR框架性能的影响。由于ImageNet数据库中的图像大小并不一致,在使用数据时通常已经进行了图像裁剪和调整操作,由于很难剔除随机裁剪的影响,因此作者通过不对称转换改善这一情况。

具体来说,首先对数据样本次啊用随机裁剪,并将其调整到相同的分辨率大小,之后对一个框架的两个分支中的一个应用目标转换 t ( ⋅ ) t(·) t(),而另一个分支作为标志,即令: t ( x i ) = x i t(x_i)=x_i t(xi)=xi

作者认为这种非对称数据的增加会对框架性能产生影响,但不应该在实质上改变单个数据增强或组合数据增强方法的影响。

作者测评了不同方法和组合方法的top-1识别结果,如下图所示。可以发现,单一数据增强方法都不足以学到良好的表征,即使模型可以识别出任务中的正对,组合增强方法会增大预测难度,但是能够显著提升表征质量,最后作者发现,随机裁剪和颜色失真的组合最有利于学习表征。

SimCLR:用于视觉表征的对比学习框架_第6张图片

3.3 对比学习比监督学习更需要数据增强

作者调整了颜色增强的强度,发现颜色增强大大提升了SimCLR的线性评估结果。而监督学习使用更加复杂的自动增强方法时,表征学习的表现也没有比简单的裁剪+颜色失真效果好,不同的颜色增强强度也没有提升或削弱监督学习的表现。因此作者认为对比学习比监督学习可以从颜色增强中获得更大收益。

SimCLR:用于视觉表征的对比学习框架_第7张图片

4 Encoder 和Head 的架构

4.1 更大的模型使无监督对比模型表现更好

下图中绿色“×”是有监督ResNet 训练90个epoch的结果,红色“★”是本文模型训练100个epoch的及如果,蓝色“·”是本文模型训练100个epoch的结果,可以发现,随着网络参数的增加,有监督和无监督模型的表现都有所提升,但无监督模型的提升更大。

SimCLR:用于视觉表征的对比学习框架_第8张图片

4.2 非线性投影head有助于提升其之前网络层的表征质量

非线性投影head,即前文提到的 g ( h ) g(\boldsymbol{h}) g(h),本文对比了包括非线性投影、线性投影以及无投影在内的三种映射方式对表征质量的影响,如下图所示。最终发现非线性投影的Top-1精度比无映射提升10%以上,比线性投影提升3%,并且这一结果在只使用一个head时不受输出维度的影响。

SimCLR:用于视觉表征的对比学习框架_第9张图片

此外,作者提到使用非线性投影时,head前面的隐藏层也可以学到更好的表征。

同时,作者通过对使用非线性投影后的映射 z z z 以及非线性投影之前的映射 h \boldsymbol{h} h对不同表征的预测表现进行对比,发现 z z z h \boldsymbol{h} h保留了更少的信息,具体如下表所示。作者认为是因为 z z z删除了一些颜色之类的信息,使得在训练过程中这些信息更多在 h \boldsymbol{h} h阶段形成。

SimCLR:用于视觉表征的对比学习框架_第10张图片

5 损失函数和Batch Size

5.1 带有可调温度参数的归一化交叉熵损失是更优选择

作者对比了本文使用的NT-Xent损失和下表中另外两种对比损失应用于此框架的表现。

SimCLR:用于视觉表征的对比学习框架_第11张图片

作者认为NT-Xent损失中的温度参数可以帮助模型从hard-negative中学习,而其他损失函数没有这个特点,必须引入额外的semi-hard-negative进行挖掘,在均使用l2正则化(余弦相似性)时,发现即使是用了semi-hard-negative的其他loss,表现也不如NT-Xent。

SimCLR:用于视觉表征的对比学习框架_第12张图片

在此基础上,作者又测试了使用l2正则化和只使用点积的情况以及不同温度参数对对比任务精度和通过Top-1预测的表征表现情况,发现不使用l2正则化会提升对比任务精度,但是降低了表征的线性评估结果。

SimCLR:用于视觉表征的对比学习框架_第13张图片

5.2 大Batch Size和更多训练对对比学习更有益

作者通过实验发现,大Batch Size 和大epoch 对框架的表现具有提升,因为增大这两者会提供更多的负例。不过随着epoch的增加,不同Batch Size效果之间的差异在减小。

SimCLR:用于视觉表征的对比学习框架_第14张图片

6 实验对比

6.1 线性评估

这部分实验用预训练好的一些无监督模型训练了一个线性分类器,评估在ImageNet上的效果,可以发现,SimCLR的Top-1和Top-5表现都很优秀,当加深网络时,无监督模型的表现普遍提升(括号中的(4×)的意思是网络宽度是原网络的4倍)。

SimCLR:用于视觉表征的对比学习框架_第15张图片

6.2 半监督学习

对ImageNet采样1%和10%有标签的训练集(分别是12.8张与128张图片一个类)进行微调,比较不同网络的效果。

微调过程使用名为Nesterov动量优化器进行调整,对于1%标记的数据,对60个epoch进行微调,对于10%标签的数据,对30个epoch进行微调,同时调整图像大小为256×256,之后在中央应用裁剪,裁剪到224×224的大小。

SimCLR:用于视觉表征的对比学习框架_第16张图片

6.3 迁移学习

SimCLR:用于视觉表征的对比学习框架_第17张图片

你可能感兴趣的:(Cantrastive,Learning,计算机视觉,神经网络,深度学习)