对比学习损失函数中超参数temperature的作用

目录

  • 背景
  • 超参数temperature的直观理解
  • 对比学习中的temperature参数理解

背景

最近在看凯明大神的对比学习MOCO时,看到infoNCE loss的公式时,对其中参数T(应该是tao,打不出来,就浅用T代替一下)有点费解,于是查阅了一些资料,记录一下自己的理解。

首先,附上infoNCE loss的公式。

L q = − l o g e x p ( q ⋅ k + / τ ) ∑ i = 0 k e x p ( q ⋅ k i / τ ) L_q = -log\frac{exp(q\cdot k_+ / \tau )}{\sum_{i=0}^{k}exp(q\cdot k_i / \tau)} Lq=logi=0kexp(qki/τ)exp(qk+/τ)
在其他地方看到这个公式的另一种写法,感觉更容易理解,也更加一般。

L ( x i ) = − l o g e x p ( s i , i / τ ) ∑ k ≠ i e x p ( s i , k / τ ) + e x p ( s i , i / τ ) L(x_i) = -log\frac{exp(s_{i,i}/ \tau )}{\sum_{k\neq i}exp(s_{i,k} / \tau) + exp(s_{i,i} / \tau)} L(xi)=logk=iexp(si,k/τ)+exp(si,i/τ)exp(si,i/τ)
简单的来说,这个公式就是熟悉的cross entropy loss(交叉熵损失)的一个变体。 s i , i s_{i,i} si,i是指当前特征与正样本间的相似度,这个相似度可以用点乘,也可以用其他方式计算。 s i , k s_{i,k} si,k是指当前特征与负样本之间的相似度。

其次,在对比学习中,每张图片相当于一个类别,对于每张图片,通过对自身数据增强后的图片为正例,其余所有图片都是负例。对比学习的目的是尽量使正例之间的相似度相近,且负例之间的相似度越低越好。换一句话说,就是要训练一个特征提取网络,使得所有图片在特征空间中的特征向量都尽可能的分开。下面是在InstDisc文中一副对比学习图片。

对比学习损失函数中超参数temperature的作用_第1张图片

感谢以下大佬的博文和论文,本文是在他们的基础上写作的
https://blog.csdn.net/qq_36560894/article/details/114874268
(CVPR2021)理解对比损失的性质以及温度系数的作用:arxiv
对上面论文的理解:知乎

超参数temperature的直观理解

由于infoNCE loss是交叉熵损失的一个变体,为了更加直观的理解,我们先在交叉熵损失中加入temperature,看一下有什么样的效果。

假设一个三分类的问题,预测图片是猫,狗还是猪。特征提取到最后一层,输出为[1 ,2, 3],假设预测正确,结果确实是猪,那么交叉熵应当这样计算。

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()
x = torch.Tensor([[1, 2, 3]])
y = torch.Tensor([2]).type(torch.long)

# 当temperature=1时
t = 1
out = F.softmax(x/t, dim = 1)
print("after softmax:"+str(out))
loss = criterion(x, y)
print("loss:"+str(loss))

# 输出为
# after softmax:tensor([[0.0900, 0.2447, 0.6652]])
# loss:tensor(0.4076)

# 当temperature=0.5时
t = 0.5
out = F.softmax(x/t, dim = 1)
print("after softmax:"+str(out))
loss = criterion(x, y)
print("loss:"+str(loss))

# 输出为
# after softmax:tensor([[0.0159, 0.1173, 0.8668]])
# loss:tensor(0.1429)

# 当temperature=0.1时
t = 0.1
out = F.softmax(x/t, dim = 1)
print("after softmax:"+str(out))
loss = criterion(x, y)
print("loss:"+str(loss))

# 输出为
# after softmax:tensor([[2.0611e-09, 4.5398e-05, 9.9995e-01]])
# loss:tensor(4.5418e-05)

不难看出,当分类结果正确时,当temperature越小时,softmax输出各类别的分数差别越大,loss越小。

可以在尝试当分类结果错误,比如当正确分类结果为0时,即y=torch.Tensor([0]).type(torch.long)时,有着这样的规律:当分类结果错误时,当temperature越小时,softmax输出各类别的分数差别越大,loss越大。

对比学习中的temperature参数理解

讲完了对比学习的背景和超参数T的直观理解,我们有如下结论:

(1)对比学习的目的是训练一个特征提取网络,使得所有特征向量在特征空间中尽可能的远离。

(2)当分类结果错误时,当temperature越小时,softmax输出各类别的分数差别越大,loss越大。

我们下面开始讲在对比学习loss中加入temperature参数解决的核心问题:困难负样本问题。

困难负样本,就是一张图像经过特征提取网络后,发现自己相较于自身数据增强后的图片特征,更相似于其他图片提取出的特征。但是,相似度并没有差很多。这样的样本我们就叫他困难负样本。

如果没有引入temperature参数,当有困难负样本过来时,loss相对较小,对参数的惩罚也就较小。由于我们希望所有特征向量尽量远离,因此,必须对所有错误分类的样本都加大惩罚,所以,要加入一个小于1的temperature参数,来放大对于困难负样本的惩罚。

讲到这,对比学习中的temperature参数其实就已经讲的差不多了,下面再略微提一下Uniformity-Tolerance Dilemma,也就是均匀性-容忍性困境。

这里又要说到对比学习的目标,是通过大规模自监督学习去训练一个能够很好提取特征的特征提取网络,说到底,就是一个代理任务,这个使所有图片特征尽量分开的任务本身是没有任何意义的,只是用来去训练特征提取网络。训练好这个特征提取网络后,就可以加上不同的检测头来执行一系列的下游任务,如检测、分割等。

说回temperature参数,考虑一下出现困难负样本的原因,有可能是因为两张图片确实非常相似,通常是两张图片有着相同的前景,让算法产生了混淆。也就是说,其实网络已经学到了一定的语义特征,这对下游任务是有帮助的,强行将两张非常相似图片提取出的特征相互远离,有可能打破这种语义信息,导致在执行下游任务时,效果不升反降。

因此,调temperature参数是一个很讲究的事情,太高不能很好的训练特征提取网络,太低又会打破模型学到的语义信息,损害下游任务的准确度。
这种语义信息,导致在执行下游任务时,效果不升反降。

因此,调temperature参数是一个很讲究的事情,太高不能很好的训练特征提取网络,太低又会打破模型学到的语义信息,损害下游任务的准确度。

你可能感兴趣的:(自监督学习笔记,深度学习,人工智能,pytorch,迁移学习)