损失函数InfoNCE loss和cross entropy loss以及温度系数

还是基础知识的搬运哦

(1)对比学习常用的损失函数InfoNCE loss和cross entropy loss是否有联系?

(2)对比损失InfoNCE loss中有一个温度系数,其作用是什么?温度系数的设置对效果如何产生影响?

个人认为,这两个问题可以作为对比学习相关项目面试的考点,本文我们就一起盘一盘这两个问题。

1. InfoNCE loss公式

对比学习损失函数有多种,其中比较常用的一种是InfoNCE loss,InfoNCE loss其实跟交叉熵损失有着千丝万缕的关系,下面我们借用恺明大佬在他的论文MoCo里定义的InfoNCE loss公式来说明。论文MoCo提出,我们可以把对比学习看成是一个字典查询的任务,即训练一个编码器从而去做字典查询的任务。假设已经有一个编码好的query  (一个特征), 以及一系列编码好的样本 , 那么  可以看作是字典里的key。假设字典里只有一个  即  (称 为  positive) 是跟  是匹配的,那么  和  +就互为正样本对, 其余的  为  的负样本。一旦定义 好了正负样本对, 就需要一个对比学习的损失函数来指导模型来进行学习。这个损失函数需要满足 这些要求, 即当query  和唯一的正样本  相似, 并且和其他所有负样本key都不相似的时候, 这 个loss的值应该比较低。反之, 如果  和  不相似, 或者  和其他负样本的key相似了, 那么loss就 应该大, 从而惩罚模型, 促使模型进行参数更新。

损失函数InfoNCE loss和cross entropy loss以及温度系数_第1张图片

2. InfoNCE loss和交叉熵损失有什么关系?

我们先从softmax说起,下面是softmax公式:

损失函数InfoNCE loss和cross entropy loss以及温度系数_第2张图片

上式中的  在有监督学习里指的是这个数据集一共有多少类别, 比如CV的ImageNet数据集有 1000 类, k就是1000。

对于对比学习来说,理论上也是可以用上式去计算loss,但是实际上是行不通的。为什么呢?

还是拿CV领域的ImageNet数据集来举例,该数据集一共有128万张图片,我们使用数据增强手段(例如,随机裁剪、随机颜色失真、随机高斯模糊)来产生对比学习正样本对,每张图片就是单独一类,那k就是128万类,而不是1000类了,有多少张图就有多少类。但是softmax操作在如此多类别上进行计算是非常耗时的,再加上有指数运算的操作,当向量的维度是几百万的时候,计算复杂度是相当高的。所以对比学习用上式去计算loss是行不通的。

怎么办呢?NCE loss可以解决这个问题。

NCE(noise contrastive estimation)核心思想是将多分类问题转化成二分类问题,一个类是数据类别 data sample,另一个类是噪声类别 noisy sample,通过学习数据样本和噪声样本之间的区别,将数据样本去和噪声样本做对比,也就是“噪声对比(noise contrastive)”,从而发现数据中的一些特性。但是,如果把整个数据集剩下的数据都当作负样本(即噪声样本),虽然解决了类别多的问题,计算复杂度还是没有降下来,解决办法就是做负样本采样来计算loss,这就是estimation的含义,也就是说它只是估计和近似。一般来说,负样本选取的越多,就越接近整个数据集,效果自然会更好。

NCE loss常用在NLP模型中,公式如下:

损失函数InfoNCE loss和cross entropy loss以及温度系数_第3张图片

上述公式细节详见:NCE loss(https://arxiv.org/pdf/1410.8251.pdf)

有了NCE loss,为什么还要用Info NCE loss呢?

Info NCE loss是NCE的一个简单变体,它认为如果你只把问题看作是一个二分类,只有数据样本和噪声样本的话,可能对模型学习不友好,因为很多噪声样本可能本就不是一个类,因此还是把它看成一个多分类问题比较合理(但这里的多分类kk指代的是负采样之后负样本的数量,下面会解释)。于是就有了InfoNCE loss,公式如下:

 

上式中,  是模型出来的logits, 相当于上文  oftmax公式中的  是一个温度超参数, 是个标 量, 假设我们忽略 , 那么infoNCE loss其实就是cross entropy loss。唯一的区别是, 在cross entropy loss里,  指代的是数据集里类别的数量, 而在对比学习InfoNCE loss里, 这个k指的是负样本的数量。上式分母中的sum是在 1 个正样本和  个负样本上做的, 从0到 , 所以共  个样本, 也就是字典里所有的key。恺明大佬在MoCo里提到, InfoNCE loss其实就是一个cross entropy loss, 做的是一个  类的分类任务, 目的就是想把  这个图片分到  这个类。

另外,我们看下图中MoCo的伪代码,MoCo这个loss的实现就是基于cross entropy loss。

损失函数InfoNCE loss和cross entropy loss以及温度系数_第4张图片

3. 温度系数的作用

温度系数  虽然只是一个超参数, 但它的设置是非常讲究的, 直接影响了模型的效果。上式Info NCE loss中的  相当于是logits, 温度系数可以用来控制logits的分布形状。对于既定的logits分 布的形状, 当  值变大, 则  就变小,  则会使得原来logits分布里的数值都变小, 且经过指数运算之后, 就变得更小了, 导致原来的logits分布变得更平滑。相反, 如果  取得值小,  就 变大, 原来的logits分布里的数值就相应的变大, 经过指数运算之后, 就变得更大, 使得这个分布变得更集中, 更peak。

如果温度系数设的越大,logits分布变得越平滑,那么对比损失会对所有的负样本一视同仁,导致模型学习没有轻重。如果温度系数设的过小,则模型会越关注特别困难的负样本,但其实那些负样本很可能是潜在的正样本,这样会导致模型很难收敛或者泛化能力差。

总之,温度系数的作用就是它控制了模型对负样本的区分度。

  whaosoft aiot http://143ai.com  

 

你可能感兴趣的:(人工智能,人工智能)