自监督对比学习框架SimCLR原理

目录

一、前言

人工智能发展近况

对比学习

二、数据集介绍

STL-10数据集

三、无监督图像表征对比学习

SimCLR

SimCLR算法基本原理

数据增强与正负样本匹配

编码器

损失函数

对比学习全过程

四、有监督的图像下游任务迁移

替换下游任务网络层

有监督训练

五、实验

无监督训练

有监督训练

六、结论


一、前言

人工智能发展近况

2023年,由于transformer等模型的应用,人工智能领域的研究迎来了爆炸式的增长。3月8日,微软发布了Visual Chat GPT,在聊天时不光可以使用文字,还可以使用图片;3月9日,Adobe团队发布10亿参数的大模型GigaGAN,效果不逊色于基于扩散模型的DALLE2;3月13日,斯坦福大学发布700万参数模型Alpaca,可与OpenAI17亿参数的Text-davinvi-003模型相媲美;3月14日,GPT-4发布,GPT-4由语言模型改进为多模态模型,可以接收图片输入;3月15日,Midjourney发布第五代文生图模型,效果惊人。

随着人工智能技术的飞速发展,越来越多的研究人员、企业和组织都开始投入大量精力和资源来训练和构建更加复杂和庞大的机器学习模型。这些大模型通常具有比小模型更高的准确率和鲁棒性,然而,它们需要更大规模、更多样化和更丰富的数据集进行训练。

对于一些常见的任务,例如图像分类、语音识别以及自然语言处理等领域,为了训练一个高质量的大模型,需要耗费无数的时间和金钱来获取高质量的标注数据集。因此,数据标注代价也在不断攀升,成为了进一步推动人工智能技术发展的一个瓶颈。

对比学习

为了解决标注数据成本高昂,难以获取足够的有标签数据的问题,自监督学习的方法也被广泛应用于大模型的训练中。

对比学习(Contrastive Learning)是近年来自督学习领域中备受关注的一种方法。它通过学习样本及其变换的表示之间的关系,用无标签数据来提取出良好的特征表示。相比于传统的监督学习方法,对比学习不依赖于代价高昂的人工标注,可以很好的将无标签数据与有标签数据结合利用,对于模型性能的提升和大规模数据应用具有重要意义。

随着深度神经网络在计算机视觉、自然语言处理等领域的应用日益广泛,在使用深度学习模型解决各种实际问题过程中,如何知道模型提取到的特征是否是有效、鲁棒的成为了一个重要的问题。从这个角度看,对比学习作为自监督学习的一项重要研究方向,对于提高模型的泛化能力、充分利用未标记的数据做出了积极贡献。

本文将对比学习的基本原理进行详细介绍,并探讨其在图像分类领域的应用。

二、数据集介绍

STL-10数据集

STL-10数据集是一个用于开发无监督特征学习、深度学习、自监督学习算法的图像识别数据集。其灵感来自CIFAR-10数据集,每个类别的标注图像数量相比CIFAR-10中的要少,但提供了大量的无标注图像来做无监督预训练,其主要的挑战是利用无标签图像来构建先验知识。

STL-10的图像来自ImageNet,共有113000张96 x 96分辨率的RGB图像,其中训练集为5000张,测试集为8000张,其余100000张均为无标签图像。

STL-10数据集共有10类,其索引与标签如表1所示:

表1 STL-10数据集标签索引与标签名

Index

0

1

2

3

4

5

6

7

8

9

Name

airplane

bird

car

cat

deer

dog

horse

monkey

ship

truck

图1展示了STL-10数据集中的数据实例。

自监督对比学习框架SimCLR原理_第1张图片

图1 STL-10数据集中的数据实例

三、无监督图像表征对比学习

对比学习是从已知的数据对中学习相似度的一种方法。其目的是通过在相似性计算中考虑不同样本之间的相对位置,来学习样本之间的相似性和差异性。

对于对比学习而言,通常使用一个映射函数将样本映射到一个低维的特征空间内,在该特征空间内更好地刻画不同样本之间的关系。通过对这个映射函数进行学习,我们可以获得更加鲁棒的相似度度量方式。

SimCLR

SimCLR是一种基于对比学习的无监督学习方法,是对比学习中一个较为经典的模型框架。其主要思想是通过最大化同一图像的不同数据增强版本之间的相似性,来学习有用的特征表示。

SimCLR中使用了一组称为数据扩充策略的操作,将原始图像变换成不同的图像副本。然后将这些图像对作为输入样本输入到一个编码器中(通常是ResNet),并计算编码器输出在高维特征空间中的相似性表示。

SimCLR算法基本原理

图2显示了一个简单的视觉对比学习框架。该框架包括两个主要部分:数据增强与特征学习。

自监督对比学习框架SimCLR原理_第2张图片

图2 SimCLR基本框架

在数据增强中,每个数据样本会被随机选取两种不同的数据增强方法,比如是旋转、随机裁剪和变形等,然后得到两张相关的图像视图。

在特征学习中,使用了一个编码器网络f(·)和一个投影头g(·)。将两个图像视图输入编码器网络产生两个输出向量,并通过投影头转化到同一维度空间,然后根据它们的相似性计算对比损失,让两个向量更加接近,从而获得更有效表示方法。

最后,在训练完成后,去掉投影头g(·),使用编码器f(·)和表示h用于下游任务。这样,该模型可以学习到有意义的视觉特征表示,并且可以进行各种任务的迁移学习。

数据增强与正负样本匹配

对比学习属于无监督学习。无监督学习是不需要标签的,但在图像分类任务中,既然要做分类,就要知道这个图像是属于哪一个类别的,因此就需要我们主动的产。生一些标签。

对于图像分类任务而言,可以通过图像数据增强的方式产生正样本,图像数据增强策略如图3所示。对于一张狗的图像,对其进行随机裁剪、随机翻转、颜色变换、添加噪声等图像数据增强后,图像还是一张狗。此时数据增强前后的图像都是属于同一个类别的,对比学习的“标签”就是这样产生的。

自监督对比学习框架SimCLR原理_第3张图片

图3 图像数据增强策略

为了增强图像特征的多样化,对同一张图像进行两次不同的数据增强。我们将对于同一张图像的两次不同的数据增强作为一个正样本对,那么这个batch中其他的图像就都作为负样本。为了方便展示,暂时将batch size设为8,并将batch中的所有图像进行可视化显示,每个batch得到的图像如图4所示:

自监督对比学习框架SimCLR原理_第4张图片

自监督对比学习框架SimCLR原理_第5张图片

图4 batch size为8时程序中的图像增强及匹配情况

当batch size为8时,每个batch中都有16张图像,第一行为对原图进行第一次数据增强的结果,第二行为对原图进行第二次数据增强的结果。同一列的两张图像则是对于同一张图像的两次不同的数据增强。我们将这两张图像作为一个正样本对,因为他们描述的时同一种信息,其他的七个样本对则全部视为负样本。例如图像0与图像8就是一个正样本对,而相对于图像0与图像8,其余所有的图像都视为负样本。在对比学习中,我们并不过度关注图像中的“语义特征”,而是更加关注“表征信息”。

在分类任务中,所谓的“语义特征”即神经网络对于某一类别在特征空间中的抽象描述,抓住了某一类别的本质特征,这种特征容易在有监督学习中的深层神经网络中被捕获;而“表征特征”则更像神经网络对某一类别的多尺度与多视角的特征,具有更佳的泛化能力与鲁棒性。

编码器

有了正样本与负样本的匹配规则,接下来要对正负样本进行统一的处理。在常规的有监督图像分类任务中,通常将图像输入进卷积神经网络中,最后得到一个用于表示不同类别概率的向量。图5展示了一个图像二分类任务的输入与输出,卷积神经网络这里可以理解为一个映射的过程,即由图像映射为向量。

自监督对比学习框架SimCLR原理_第6张图片

图5 卷积神经网络处理二分类问题的输入与输出

从某种角度来说,卷积神经网络其实就是一个编码器,它将一张图像编码为一个向量。在对比学习的任务中,就是使用卷积神经网络作为了一个编码器。每张图像经过卷积神经网络的编码,都会变为一条长度为128(可根据需求修改长度)的特征向量。图像被编码为特征向量后,其特征被卷积神经网络极大程度的压缩并保留了下来。在理想状态下,两个正样本由同一卷积神经网络编码产生的两个特征向量,是高度相似的。

损失函数

在常规的有监督图像分类任务中,损失函数是用来衡量预测值与标签值之间差异的,然后将产生的损失回传,并不断更新神经网络中的参数,使得损失最小。

在对比学习任务中,我们希望正样本之间产生的特征向量尽可能的相似,而正样本与负样本尽可能的不相似。我们通过计算两个特征向量之间的余弦相似度(Cosine Similarity)来衡量两个特征向量的相似度。

余弦相似度是将两个向量映射到高维空间中,计算它们之间夹角的余弦值作为相似度。其计算公式如公式(1)所示:

如果两个向量的在高维空间中的方向越接近,那么它们之间的夹角就越小,其余弦值也就越接近1。反之,如果它们的方向完全相反,则余弦值为-1。如果它们之间完全随机(即不相关),则余弦值的期望值接近0。

对比学习的任务中提出了一种损失函数:InfoNCE loss。InfoNCE loss是交叉熵损失函数的一种变体,计算公式如公式(2)所示。

自监督对比学习框架SimCLR原理_第7张图片

通过分析以上的公式,我们可以发现,InfoNCE loss实际上是将两个向量的余弦相似度,带入softmax函数,再带入-log函数,最后取期望的损失函数,其目的是让xixi+ (某一样本与其匹配的正样本)之间的余弦相似度尽可能的大,而与xj (所有样本)之间的余弦相似度尽可能的小。

下面通过具体的实图来分析计算损失的过程。

首先我们通过数据增强,得到两个正样本对,分别为Pair 1猫与Pair 2象的正样本对,如图6所示。

自监督对比学习框架SimCLR原理_第8张图片

图6 获取正样本对

接下来,我们计算Softmax函数,分子为某样本与其匹配的正样本之间的余弦相似度,分母则为这一样本与所有样本之间的余弦相似度。我们的目的是让softmax函数的分子尽可能的变大,分母尽可能的变小,也就是样本正样本之间的余弦相似度尽可能的大,与负样本之间的余弦相似度尽可能的小。

自监督对比学习框架SimCLR原理_第9张图片

图7 Softmax函数的计算

因为我们希望Softmax的函数值尽可能的大,而损失函数在优化的过程中是越来越小的,因此在Softmax的基础上,取-log,使得损失函数最小化的同时,保持Softmax的函数值尽可能的大,如图8所示。

图8 在Softmax函数的基础上取-log

最后,对所有样本计算损失函数,相加后取均值,做为最终的损失,最终的损失函数如图9所示。

自监督对比学习框架SimCLR原理_第10张图片

图9 最终的损失函数

对比学习全过程

总结一下对比学习的全过程,如图10所示。

自监督对比学习框架SimCLR原理_第11张图片

图10 对比学习全过程

在获取到数据集中的某张图像时,先对其进行两次随机的数据增强,得到两张图像作为一个正样本对;同时,这个batch中的其他图像也会进行两次随机的数据增强,得到其他的正样本对,而相对于每个正样本对而言,其余所有图像均视为负样本。

接下来,将所有图像输入到编码器中,这里将CNN层与MLP层共同组成的网络结构作为一个编码器。MLP层的加入可以是向量具有更多的非线性变化,能够让特征向量在特征空间中获得更好的表示。

通过CNN层与MLP层的编码,每张图像都被编码成了一条特征向量。接下来就要通过不同的特征向量,计算其InfoNCE损失,并根据损失,在反向传播的过程中,不断更新编码器中的权重项与偏置项,使得正样本对产生的特征特征向量对之间的相似性尽可能的大。

以上过程就是通过SimCLR进行对比学习的全过程。

四、有监督的图像下游任务迁移

替换下游任务网络层

图像的上游任务一般指的是网络结构的预训练,可以得到一个具有特征提取能力的卷积神经网络或其他网络结构,但具体任务不明确,只是用来做特征的提取;而图像的下游任务,则是有着明确的任务目标,如分类、检测、分割等等。前面说到的对比学习,实际上就是图像的上游任务。

自监督任务要投入到实际的应用中,就必须要做下游任务的迁移。通过对比学习的训练,我们只是得到了一个编码器,并不能完成某项具体的下游任务。此时则需要少量的有监督数据,来完成下游任务的迁移。

具体来说,我们通过对比学习的方式训练好一个由CNN层与MLP层编码器时,通常会将编码器中的MLP层替换为满足下游任务需求的网络层。例如,在原始的编码器中,将图像编码为一条长度为128的向量,这就意味着编码器中的MLP层的输出层有128个神经元;若将这个编码器迁移到图像的10分类任务中,就要将这个编码器替换为输出层为10个神经元的MLP层。为了方便起见,我们将这个针对分类这一下游任务而替换的MLP层,称为 “分类头”。

有监督训练

分类头完成替换后,此时的分类头中的权重都是随机,我们需要再次进行训练,在编码器的基础上,利用有监督的数据,完成分类头的训练。

由于此时卷积神经网络的backbone是通过无监督的数据训练得到的,相当于神经网络已经存在了有关特征提取的先验知识,因此只需要使用少量的有监督数据,就可以使最后的全连接层找到有每个类别的分类决策边界。

五、实验

我们使用ResNet-18作为网络基本的编码器,输出特征向量的长度为128,并将其迁移到STL-10数据集图像10分类的下游任务。实验分为无监督训练与有监督训练两部分。

无监督训练

无监督训练的各项超参数如表2所示:

表2 无监督训练的各项超参数

超参数

优化器

学习率

轮数

批大小

权重衰减

特征向量

Name

Adam

0.001

100

256

0.0001

128

为了评估对比学习的效果,我们根据对比学习的匹配规则,产生了对比学习的标签,仅用于评估。我们计算了通过产生的标签计算了预测值的准确率与交叉熵损失,其变化如图11所示。

自监督对比学习框架SimCLR原理_第12张图片

图11 无监督训练过程中的准确率与损失变化

训练准确率在20轮以后就趋于收敛,最后的几轮达到了100%;交叉熵损失在100轮时仍有下降的趋势。说明无监督训练实现了非常理想的训练效果。但无监督训练难以评估是否出现了过拟合的现象。

为了更直观的展示无监督学习的训练效果,我们将batch size设为8,并将生成的图像与训练不同阶段的相似度矩阵截取出来,做了可视化的展示,来反映编码器的训练效果。训练不同阶段的相似度矩阵如图12所示:

自监督对比学习框架SimCLR原理_第13张图片

自监督对比学习框架SimCLR原理_第14张图片

 自监督对比学习框架SimCLR原理_第15张图片

自监督对比学习框架SimCLR原理_第16张图片

图12 训练不同阶段的相似度矩阵

由图可见,在训练的初始化阶段,也就是未经训练的编码器,特征向量之间的相似度是非常随机的,没有什么规律。但经过对比学习的训练后,相似度矩阵中出现了两条亮度较高的斜线,而其余的像素点的颜色都被抑制。仔细观察图像,亮度较高的两条斜线是由正样本对产生的,说明经过训练的编码器,实现了正样本对之间的特征向量尽可能的相似,而与其他的负样本尽可能的不相似,达到了想要的效果。

有监督训练

有监督训练的各项超参数如表3所示:

表3 有监督训练的各项超参数

超参数

优化器

学习率

轮数

批大小

权重衰减

损失

Name

Adam

0.001

100

256

0.0001

交叉熵损失

我们首先使用ResNet-18模型,在没有预训练权重的情况下,使用有监督的方式,对STL-10数据集进行了训练和测试,将其作为baseline。训练过程如图13所示:

自监督对比学习框架SimCLR原理_第17张图片

图13 没有预训练权重的ResNet18训练过程

在训练集上,模型很快收敛,准确率达到了100%,但由于测试集的数据量较多,模型在测试集上的准确率表现不佳,并在后期出现了测试集损失上升的现象,发生了典型的过拟合现象。

在得到通过对比学习训练的编码器之后,我们将最后的全连接层,替换为了针对STL-10数据集的十分类任务的分类头。为了验证无监督学习的特征提取效果,我们首先对ResNet-18模型的卷积层进行了冻结,即只训练分类头,而不训练卷积神经网络,以此来研究无监督提取的特征是否能够有效的迁移至有监督的下游任务中。我们将此模型称为“编码冻结ResNet-18模型”。

我们没有对超参数进行调整,直接训练了分类头,训练过程如图14所示:

自监督对比学习框架SimCLR原理_第18张图片

图14编码冻结ResNet-18模型训练过程

编码冻结ResNet-18模型在训练集上达到了70%的准确率,训练集损失也是高于baszeline的,但其在验证集上的效果缺完全优于baseline,且还存在着上升趋势。这也证明了自监督学习的有效性。

模型在训练集上表现不佳的原因是,卷积层被冻结,无法更新参数,致使输入分类层的特征向量掺杂了较大的损失。接下来,我们解除对卷积神经网络的冻结,将其称为“迁移对比学习ResNet模型”,并做训练实验,实验过程如图15所示:

自监督对比学习框架SimCLR原理_第19张图片

图15 迁移对比学习ResNet-18模型训练过程

迁移对比学习ResNet-18模型在训练集上的准确率以及损失都达到了与baseline相近的水准,证明其很好的拟合了训练集;在测试集上,其准确率达到了75%,是所有模型中最高的,其测试集损失也是三个模型中最小的,但也出现了过拟合的现象。

三种模型最后一轮的具体指标对比如表4所示:

表4 模型各指标对比

模型

训练集损失

训练集准确率

测试集损失

测试集准确率

ResNet-18(baseline)

0.0041

100.0

1.6860

50.40

冻结编码ResNet-18

0.9093

70.06

1.0023

67.48 (+17.08)

迁移对比学习ResNet-18

0.0119

100.0

0.9898

75.74 (+25.34)

为了更直观的比较三种模型的不同,我们将三个模型的准确率绘制在了同一张图上。图像如图16所示。

自监督对比学习框架SimCLR原理_第20张图片

图16 各个模型的准确率对比

六、结论

本文通过STL-10数据集,证明了对比学习的优越性。对比学习相对于传统的单一任务学习具有独特的优越性。在对比学习中,模型可以通过同时学习和比较多个任务之间的差异和共性,从而提高其泛化能力和性能。对比学习在提高模型性能、处理数据稀缺和解决领域偏移等方面具有明显的优越性。通过充分利用任务之间的关系和知识传递,对比学习能够加速新任务的学习过程,并提升模型的泛化能力,从而在各种应用场景中发挥重要作用。

时至今日,自监督学习的应用逐渐增多,并在各个领域展现出巨大的潜力和影响力。自监督学习提供了一种无监督学习的方法。相比于传统的无监督学习方法,自监督学习通过设计自动生成标签的预测任务,不需要人工标注数据,从而降低了数据获取和标注成本。这使得模型可以更好地利用大规模无标签数据,为处理实际中存在大量无标签数据的场景提供了可行的解决方案。

你可能感兴趣的:(人工智能,深度学习,机器学习,计算机视觉,cnn,神经网络,图像处理)