【损失函数】Contrastive Loss, Triplet Loss and Center Loss

文章目录

  • 1. 损失函数
    • Contrastive Loss [1]:
    • Triplet Loss [2]:
    • Center Loss[3]:
  • 2. 问题引入:
  • 3. Contrastive Loss:对比损失
    • 3.1 本质
    • 3.2 定义
    • 3.3 含义
  • 4. Triplet Loss:三元组损失
    • 4.1本质
    • 4.2 定义
    • 4.3 目标
    • 4.4 公式
    • 4.5 进阶:FaceNet
  • 5. Center Loss:
    • 5.1 定义
    • 5.2 公式

1. 损失函数

Contrastive Loss [1]:

来源:Yann LeCun论文:Dimensionality Reduction by Learning an Invariant Mapping
目的:增大分类器的类间差异

Triplet Loss [2]:

来源: FaceNet: A Unified Embedding for Face Recognition and Clustering
目的:对Contrastive Loss的改进

Center Loss[3]:

来源:A Discriminative Feature Learning Approach for Deep Face Recognition
目的:解决三元组数据激增导致的网络收敛缓慢

2. 问题引入:

假设我们有2张人脸图片,需要判断这两张人脸图片是不是对应于同一个人,一般情况下如何解决?
一种简单直接的思路就是提取图片的特征,然后对比两个特征向量的相似度。但这种简单做法存在一个明显问题,就是CNN提取的特征“类间”区分性真的有那么好吗?
用SoftMas损失函数训练出的分类模型在Minst测试集上就表现出“类间”区分边缘不大的问题,使得遭受对抗样本攻击的时候很容易分类失败。况且人脸识别需要考虑到样本的类别以及数量都非常多,这无疑使得直接用特征向量来对比更加的困难。

3. Contrastive Loss:对比损失

3.1 本质

原本主要用于降维,即本来相似的样本经过降维(特征提取)后,两个样本仍旧相似;而原本不想死的样本,经过降维后,两个样本人就不想死。

针对上面问题,提出了孪生网络,结构如下:
【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第1张图片

3.2 定义

Contrastive Loss 可以有效的处理孪生网络中的成对数据关系。
【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第2张图片
在这里插入图片描述
● W是网络权重,X是样本,Y是成对标签。
● 如果X1与X2这对样本属于同一类则Y=0,否则Y=1。
● Dw是X1与X2在潜变量空间的欧几里得距离。
● 当Y=0,调整参数最小化X1与X2之间的距离
● 当Y=1,当X1与X2间距大于m,则不做优化。当X1与X2间距大于m,则增大两者距离到m。

其实有些像Modified Huber Loss(结合了MSE与Hinge Loss)

下图展示了损失函数L和样本特征的欧氏距离的关系,红线表示相似样本的损失值,而蓝色则是不相似样本的损失值。
【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第3张图片
梯度更新
使用随机梯度下降来更新w,以得到较小loss,更好表达成对样本的匹配程度。计算loss梯度的公式如下:

  • Y=0, 两个样本相似时,梯度为
    在这里插入图片描述
  • Y=1, 两个样本不相似时,梯度为
    在这里插入图片描述

3.3 含义

可解释为:弹簧在收缩到一定程度的时候因为受到斥力的原因会恢复到原始长度。
弹簧模型公式:F=

  • a 相似是吸引力,蓝点与只吸引弹簧连接到相似的点
  • b 与相似对关联的损失函数及梯度
  • c 不相似是弹射力,蓝点仅与半径为m的圆内的非相似点用m-repulse-only弹簧连接
  • d 与非相似对关联的损失函数及梯度
  • e 一个点被不同方向的其他点拉动,形成平衡的情况
    【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第4张图片

4. Triplet Loss:三元组损失

4.1本质

能够更好地对细节进行建模,相当于加入了两个输入差异性差异的度量,学习到输入的更好表示。

4.2 定义

  • 最小化 锚点 和具有相同身份的正样本之间的距离
  • 最小化 锚点 和 具有不同身份的负样本之间的距离

4.3 目标

  • 相同标签的特征 在空间位置上尽量靠近
  • 不同标签的特征 在哦空间位置上尽量远离
  • 为了不让样本的特征聚合到一个非常小的空间中,要求同一类:2正例+1个负例,且负例应比正例的距离>=margin.
    【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第5张图片

4.4 公式

【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第6张图片
【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第7张图片

其中, α \alpha α就是 m a r g i n margin margin T T T就是样本数为 N N N的数据集的三元组。
针对三个样本的梯度公式为:
在这里插入图片描述
在这里插入图片描述在这里插入图片描述

4.5 进阶:FaceNet

描述

  • 将三元组重新描述为 ( a , p , n ) (a,p,n) (a,p,n),其中a:anchor;p:positive;n:negative。
  • 因此最小化三元组损失就是让锚点a和正样本p的距离趋于0,即 d ( a , p ) − − > 0 d(a,p)-->0 d(a,p)>0
  • 让锚点a与负样本n的距离大于 d ( a , p ) + m a r g i n d(a,p)+margin d(a,p)+margin,即 d ( a , n ) > d ( a , p ) + m a r g i n d(a,n)>d(a,p)+margin d(a,n)>d(a,p)+margin
  • 总距离:L = max(d(a,p) - d(a,n)+ margin, 0)

目标

  • a,p距离近
  • a,n距离远

定义

  • easy_triplets:代表 L = 0 L=0 L=0,
  • hard_triplets:代表 d ( a , n ) d(a,n) d(a,n)
  • semi-hard triplets: 代表 d ( a , p ) + m a r g i n d(a,p)+margin d(a,p)+margin
    【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第8张图片
    训练策略
  • 随机选取 semi-hard triplets进行训练

5. Center Loss:

5.1 定义

为每一个类别提供一个类中心,最小化min-batch中每个样本与该类中心距离,即缩小类内距离。

5.2 公式

【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第9张图片
c y i c_{y_i} cyi 就是第 y i y_i yi个类别的特征中心, x i x_i xi表示全连接层之前的特征, m m m表示mini-batch的大小。

梯度更新

  • 如果条件满足 δ \delta δ(条件)=1, 否则 δ \delta δ(条件)=0,即只有当 y i y_i yi c j c_j cj的类别 i = j i=j i=j的时候才更新 c j c_j cj
  • center可使用xavier进行初始化,然后在每个mini-batch迭代后在当前类别中更新一次。每次迭代过程中,只对对应类别的特征取平均计算,相当于让 c j c_j cj x x x的平均移动。
  • 为了避免少样本类别造成的较大干扰,采用一个引子 α \alpha α来控制类别中心的学习率
  • 用softmax Loss +Center loss联合训练。其中softmax让不同类的深度特征分开,center loss可以将同一类的特征吸引到类中心。两者结合不仅可以扩大不同类的特征区别,还可以减小同一类的特征的区别。
  • λ \lambda λ越大Center loss比重越大,类内越聚合,判别能力越大。
    【损失函数】Contrastive Loss, Triplet Loss and Center Loss_第10张图片
    在这里插入图片描述

参考:
[1] http://www.cs.toronto.edu/~hinton/csc2535/readings/hadsell-chopra-lecun-06-1.pdf
[2] https://arxiv.org/abs/1503.03832
[3] http://ydwen.github.io/papers/WenECCV16.pdf
[4]https://zhuanlan.zhihu.com/p/103278453
[5] https://zhuanlan.zhihu.com/p/76515370
[6] https://kevinmusgrave.github.io/pytorch-metric-learning/losses/

你可能感兴趣的:(知识普及,机器学习,tcp/ip,深度学习,机器学习)