度量学习和对比学习的思想是一样的,都是去拉近相似的样本,推开不相似的样本。但是对比学习是无监督或者自监督学习方法,而度量学习一般为有监督学习方法。而且对比学习在 loss 设计时,为单正例多负例的形式,因为是无监督,数据是充足的,也就可以找到无穷的负例,但如何构造有效正例才是重点。
而度量学习多为二元组或三元组的形式,如常见的 Triplet 形式(anchor,positive,negative),Hard Negative 的挖掘对最终效果有较大的影响
infoNCE loss 全称 info Noise Contrastive Estimation loss,对于一个 batch 中的样本 i,它的 loss 为:
L i = − log ( e S ( z i , z i + ) / τ / ∑ j = 0 K e S ( z i , z j ) / τ ) L_{i}=-\log \left(e^{S\left(z_{i}, z_{i}^{+}\right) / \tau} / \sum_{j=0}^{K} e^{S\left(z_{i}, z_{j}\right) / \tau}\right) Li=−log(eS(zi,zi+)/τ/j=0∑KeS(zi,zj)/τ)
要注意的是,log 里面的分母叠加项是包括了分子项的。分子是正例对的相似度,分母是正例对+所有负例对的相似度,最小化 infoNCE loss,就是去最大化分子的同时最小化分母,也就是最大化正例对的相似度,最小化负例对的相似度。
Cross Entropy loss,在输入 p 是 softmax 的输出时
L = − ∑ j = 0 K y i log ( e z i / ∑ j = 0 K e z j ) L=-\sum_{j=0}^{K} y_{i} \log \left(e^{z_{i}} / \sum_{j=0}^{K} \mathrm{e}^{z_{j}}\right) L=−j=0∑Kyilog(ezi/j=0∑Kezj)
在分类场景下,真实标签 y 一般为 one-hot 的形式,因此,CE loss 可以简化成(i 位置对应标签 1):
L = − log ( e z i / ∑ j = 0 K e z j ) L=-\log \left(e^{z_{i}} / \sum_{j=0}^{K} \mathrm{e}^{z_{j}}\right) L=−log(ezi/j=0∑Kezj)
看的出来,info NCE loss 和在一定条件下简化后的 CE loss 是非常相似的,但有一个区别要注意的是:
infoNCE loss 中的 K 是 batch 的大小,是可变的,是第 i 个样本要和 batch 中的每个样本计算相似度,而 batch 里的每一个样本都会如此计算,因此上面公式只是样本 i 的 loss。
CE loss 中的 K 是分类类别数的大小,任务确定时是不变的,i 位置对应标签为 1 的位置。不过实际上,infoNCE loss 就是直接可以用 CE loss 去计算的
def simcse_sup_loss(y_pred, t=0.05):
"""有监督的损失函数
y_pred (tensor): bert的输出, [batch_size * 3, 768]
"""
# 得到y_pred对应的label, 每第三句没有label, 跳过, label= [1, 0, 4, 3, ...]
y_true = torch.arange(y_pred.shape[0], device=DEVICE)
# [ 0, 1, 3, 4, 6, 7, 9, 10, 12, 13, 15
use_row = torch.where((y_true + 1) % 3 != 0)[0]
# [ 1, 0, 4, 3, 7, 6, 10, 9, 13, 12,
y_true = (use_row - use_row % 3 * 2) + 1
# batch内两两计算相似度, 得到相似度矩阵(对角矩阵)
sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
# 将相似度矩阵对角线置为很小的值, 消除自身的影响
sim = sim - torch.eye(y_pred.shape[0], device=DEVICE) * 1e12
# 选取有效(use_row中对应的)的行
sim = torch.index_select(sim, 0, use_row)
# 相似度矩阵除以温度系数
sim = sim / t
# 计算相似度矩阵与y_true的交叉熵损失
loss = F.cross_entropy(sim, y_true)
return torch.mean(loss)
y_true = torch.arange(y_pred.shape[0], device=DEVICE)
use_row = torch.where((y_true + 1) % 3 != 0)[0]
y_true = (use_row - use_row % 3 * 2) + 1
上面三行时为了找到y_pred对应的label,因为输入(y_pred)是[ex1, ex1_pos, ex1_neg, ex2, ex2_pos, ex2_neg, ...]
这种形式,可以看到,ex1与ex1_pos相似,因此ex1对应的label=1(index),同理ex1_pos的label=0,所以这样的话y_pred对应的label为[1,0,4,3, ...]
L ( r 0 , r 1 , y ) = y ∥ r 0 − r 1 ∥ + ( 1 − y ) max ( 0 , m − ∥ r 0 − r 1 ∥ ) L\left(r_{0}, r_{1}, y\right)=y\left\|r_{0}-r_{1}\right\|+(1-y) \max \left(0, m-\left\|r_{0}-r_{1}\right\|\right) L(r0,r1,y)=y∥r0−r1∥+(1−y)max(0,m−∥r0−r1∥)
上面 r 0 r_0 r0 和 r 1 r_1 r1是样本的表征, y y y 为 0 时表示负样本对, y y y 为 1 时表示正样本对,距离用欧拉距离来表示
对于正样本对,只有当网络产生的两个元素的表征没有距离时,损失才是0,损失会随着距离的增加而增加
对于负样本对,当两个元素的表征的距离超过边距 m m m 时,损失才是0。然而当距离小于 m m m 时,loss 为正值,此时网络参数会被更新,以调整这些元素的表达,当 负样本对的距离为 0 时,loss 达到最大值 m m m。
边距m的作用是,当负样本对产生的表征距离足够远时,就不会把精力浪费在扩大这个距离上,所以进一步训练可以集中在更难的样本上。
L ( r a , r p , r n ) = max ( 0 , m + d ( r a , r p ) − d ( r a , r n ) ) L\left(r_{a}, r_{p}, r_{n}\right)=\max \left(0, m+d\left(r_{a}, r_{p}\right)-d\left(r_{a}, r_{n}\right)\right) L(ra,rp,rn)=max(0,m+d(ra,rp)−d(ra,rn))
使用 triplet 三元组的而不是二元组来训练,模型的表现更好。Triplets 三元组由锚样本 r a r_a ra,正样本 r p r_p rp,和负样本 r n r_n rn组成。
模型的目标是锚样本和负样本表达的距离 d ( r a , r n ) d(r_a, r_n) d(ra,rn)要比锚样本和正样本表达的距离 d ( r a , r p ) d(r_a, r_p) d(ra,rp)大一个边距 m m m
todo…
细节满满!理解对比学习和SimCSE,就看这6个知识点
理解 Ranking Loss,Contrastive Loss,Margin Loss,Triplet Loss,Hinge Loss 等易混淆的概念