对比学习和度量学习loss的理解

度量学习和对比学习的思想是一样的,都是去拉近相似的样本,推开不相似的样本。但是对比学习是无监督或者自监督学习方法,而度量学习一般为有监督学习方法。而且对比学习在 loss 设计时,为单正例多负例的形式,因为是无监督,数据是充足的,也就可以找到无穷的负例,但如何构造有效正例才是重点。

而度量学习多为二元组或三元组的形式,如常见的 Triplet 形式(anchor,positive,negative),Hard Negative 的挖掘对最终效果有较大的影响

1. infoNCE loss(对比学习)

infoNCE loss 全称 info Noise Contrastive Estimation loss,对于一个 batch 中的样本 i,它的 loss 为:

L i = − log ⁡ ( e S ( z i , z i + ) / τ / ∑ j = 0 K e S ( z i , z j ) / τ ) L_{i}=-\log \left(e^{S\left(z_{i}, z_{i}^{+}\right) / \tau} / \sum_{j=0}^{K} e^{S\left(z_{i}, z_{j}\right) / \tau}\right) Li=log(eS(zi,zi+)/τ/j=0KeS(zi,zj)/τ)

要注意的是,log 里面的分母叠加项是包括了分子项的。分子是正例对的相似度,分母是正例对+所有负例对的相似度,最小化 infoNCE loss,就是去最大化分子的同时最小化分母,也就是最大化正例对的相似度,最小化负例对的相似度。

1.1 infoNCE和CE的比较

Cross Entropy loss,在输入 p 是 softmax 的输出时

L = − ∑ j = 0 K y i log ⁡ ( e z i / ∑ j = 0 K e z j ) L=-\sum_{j=0}^{K} y_{i} \log \left(e^{z_{i}} / \sum_{j=0}^{K} \mathrm{e}^{z_{j}}\right) L=j=0Kyilog(ezi/j=0Kezj)

在分类场景下,真实标签 y 一般为 one-hot 的形式,因此,CE loss 可以简化成(i 位置对应标签 1):

L = − log ⁡ ( e z i / ∑ j = 0 K e z j ) L=-\log \left(e^{z_{i}} / \sum_{j=0}^{K} \mathrm{e}^{z_{j}}\right) L=log(ezi/j=0Kezj)

看的出来,info NCE loss 和在一定条件下简化后的 CE loss 是非常相似的,但有一个区别要注意的是:

infoNCE loss 中的 K 是 batch 的大小,是可变的,是第 i 个样本要和 batch 中的每个样本计算相似度,而 batch 里的每一个样本都会如此计算,因此上面公式只是样本 i 的 loss。

CE loss 中的 K 是分类类别数的大小,任务确定时是不变的,i 位置对应标签为 1 的位置。不过实际上,infoNCE loss 就是直接可以用 CE loss 去计算的

1.2 info NCE的实现

def simcse_sup_loss(y_pred, t=0.05):
    """有监督的损失函数
    y_pred (tensor): bert的输出, [batch_size * 3, 768]
    """
    # 得到y_pred对应的label, 每第三句没有label, 跳过, label= [1, 0, 4, 3, ...]
    y_true = torch.arange(y_pred.shape[0], device=DEVICE)
    # [ 0,  1,  3,  4,  6,  7,  9, 10, 12, 13, 15
    use_row = torch.where((y_true + 1) % 3 != 0)[0]
    # [ 1,  0,  4,  3,  7,  6, 10,  9, 13, 12,
    y_true = (use_row - use_row % 3 * 2) + 1
    # batch内两两计算相似度, 得到相似度矩阵(对角矩阵)
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    # 将相似度矩阵对角线置为很小的值, 消除自身的影响
    sim = sim - torch.eye(y_pred.shape[0], device=DEVICE) * 1e12
    # 选取有效(use_row中对应的)的行
    sim = torch.index_select(sim, 0, use_row)
    # 相似度矩阵除以温度系数
    sim = sim / t
    # 计算相似度矩阵与y_true的交叉熵损失
    loss = F.cross_entropy(sim, y_true)
    return torch.mean(loss)
    y_true = torch.arange(y_pred.shape[0], device=DEVICE)
    use_row = torch.where((y_true + 1) % 3 != 0)[0]
    y_true = (use_row - use_row % 3 * 2) + 1

上面三行时为了找到y_pred对应的label,因为输入(y_pred)是[ex1, ex1_pos, ex1_neg, ex2, ex2_pos, ex2_neg, ...]这种形式,可以看到,ex1与ex1_pos相似,因此ex1对应的label=1(index),同理ex1_pos的label=0,所以这样的话y_pred对应的label为[1,0,4,3, ...]

2. Pairwise Ranking Loss(度量学习)

L ( r 0 , r 1 , y ) = y ∥ r 0 − r 1 ∥ + ( 1 − y ) max ⁡ ( 0 , m − ∥ r 0 − r 1 ∥ ) L\left(r_{0}, r_{1}, y\right)=y\left\|r_{0}-r_{1}\right\|+(1-y) \max \left(0, m-\left\|r_{0}-r_{1}\right\|\right) L(r0,r1,y)=yr0r1+(1y)max(0,mr0r1)

上面 r 0 r_0 r0 r 1 r_1 r1是样本的表征, y y y 为 0 时表示负样本对, y y y 为 1 时表示正样本对,距离用欧拉距离来表示

对于正样本对,只有当网络产生的两个元素的表征没有距离时,损失才是0,损失会随着距离的增加而增加

对于负样本对,当两个元素的表征的距离超过边距 m m m 时,损失才是0。然而当距离小于 m m m 时,loss 为正值,此时网络参数会被更新,以调整这些元素的表达,当 负样本对的距离为 0 时,loss 达到最大值 m m m

边距m的作用是,当负样本对产生的表征距离足够远时,就不会把精力浪费在扩大这个距离上,所以进一步训练可以集中在更难的样本上。

3. Triplet Ranking Loss(度量学习)

L ( r a , r p , r n ) = max ⁡ ( 0 , m + d ( r a , r p ) − d ( r a , r n ) ) L\left(r_{a}, r_{p}, r_{n}\right)=\max \left(0, m+d\left(r_{a}, r_{p}\right)-d\left(r_{a}, r_{n}\right)\right) L(ra,rp,rn)=max(0,m+d(ra,rp)d(ra,rn))

使用 triplet 三元组的而不是二元组来训练,模型的表现更好。Triplets 三元组由锚样本 r a r_a ra,正样本 r p r_p rp,和负样本 r n r_n rn组成。

模型的目标是锚样本和负样本表达的距离 d ( r a , r n ) d(r_a, r_n) d(ra,rn)要比锚样本和正样本表达的距离 d ( r a , r p ) d(r_a, r_p) d(ra,rp)大一个边距 m m m

  • Easy Triplets: d ( r a , r n ) > d ( r a , r p ) + m d\left(r_{a}, r_{n}\right)>d\left(r_{a}, r_{p}\right)+m d(ra,rn)>d(ra,rp)+m,相对于正样本和锚样本之间的距离,负样本和锚样本的距离已经足够大了,此时 loss 为 0,网络参数无需更新
  • Hard Triplets: d ( r a , r n ) < d ( r a , r p ) d\left(r_{a}, r_{n}\right)d(ra,rn)<d(ra,rp)。负样本和锚样本的距离比正样本和锚样本之间的距离还近,此时 loss 为正,且比 m 大
  • Semi-Hard Triplets: d ( r a , r p ) < d ( r a , r n ) < d ( r a , r p ) + m d\left(r_{a}, r_{p}\right)d(ra,rp)<d(ra,rn)<d(ra,rp)+m。锚样本和负样本之间的距离比和正样本大,但不超过边距 m m m,所以 loss 依然为正(但小于 m)

4. NCE Loss

todo…

参考

细节满满!理解对比学习和SimCSE,就看这6个知识点
理解 Ranking Loss,Contrastive Loss,Margin Loss,Triplet Loss,Hinge Loss 等易混淆的概念

你可能感兴趣的:(自然语言处理,python,nlp)