Finding a Better Loss for Classification Models


  • 1. CE loss is not perfect
  • 2. Potentially better loss functions
    • 2.1 CC-Loss
    • 2.2 Center-Loss
  • References

1. CE loss is not perfect

The loss function most commonly used for deep-learning classification models is arguably the cross-entropy (CE) loss. As the CC-Loss paper argues, CE loss only cares about whether the classes end up separable, not about the basic pattern-recognition requirement that different classes lie far apart while samples of the same class cluster tightly. This can lead to over-fitting and weak generalization. (A minimal sketch of the plain CE baseline is given after the list below.)

CE-Loss has two main issues that limit the performance of a CNN model for classification.

  • Firstly, the high-level features extracted by CNNs trained with CE loss are only separable from each other but not discriminative enough, which can easily lead to over-fitting and thus weak generalization.
  • Secondly, the parameters of the deep CNN model are trained jointly across all classes, which causes the high-level features of different classes to be confused with each other and increases the difficulty of optimization.
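For reference, this is the plain CE baseline that the losses in section 2 extend. A minimal PyTorch sketch, with dummy tensors standing in for a real model and data:

```python
import torch
import torch.nn as nn

# Plain cross-entropy baseline: the model outputs raw logits and
# nn.CrossEntropyLoss applies log-softmax + negative log-likelihood internally.
criterion = nn.CrossEntropyLoss()

logits = torch.randn(32, 10, requires_grad=True)   # dummy batch: 32 samples, 10 classes
labels = torch.randint(0, 10, (32,))               # dummy integer class labels

loss = criterion(logits, labels)
loss.backward()   # gradients only push classes apart; nothing enforces compact clusters
```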

2. Potentially better loss functions

2.1 CC-Loss

Its objective is the cross-entropy loss plus an extra term, essentially the sum of intra-class distances divided by the sum of inter-class distances.
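The paper's exact formulation is built on channel-attention statistics, so the following is only a rough PyTorch sketch of the idea stated above — CE plus an intra-class/inter-class distance ratio computed on feature embeddings. The names `intra_inter_ratio` and `lambda_cc` are illustrative, not from the paper:

```python
import torch
import torch.nn as nn

def intra_inter_ratio(features, labels, num_classes):
    """Sum of intra-class distances divided by the sum of pairwise
    center-to-center distances: small when classes are compact and far apart."""
    centers, intra = [], features.new_zeros(())
    for c in range(num_classes):
        mask = labels == c
        if mask.sum() == 0:          # class absent from this batch
            continue
        class_feats = features[mask]
        center = class_feats.mean(dim=0)
        centers.append(center)
        intra = intra + ((class_feats - center) ** 2).sum()
    centers = torch.stack(centers)                       # (num_present_classes, d)
    inter = torch.cdist(centers, centers).sum() + 1e-6   # all pairs, each counted twice
    return intra / inter

ce = nn.CrossEntropyLoss()
lambda_cc = 0.1   # assumed weighting factor, not taken from the paper

def cc_style_loss(logits, features, labels, num_classes=10):
    return ce(logits, labels) + lambda_cc * intra_inter_ratio(features, labels, num_classes)
```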
Its results on MNIST are, frankly, not that impressive: the gap from CE loss is small, and the learned features do not look nearly as intra-class compact and inter-class dispersed as the paper promises.

2.2 Center-Loss

See the paper A Discriminative Feature Learning Approach for Deep Face Recognition [2]. The overall strategy is also simple: on top of CE, add an extra loss term that pulls the features learned at a chosen layer as close as possible to a learned center for their class.
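For reference, the center-loss term in Wen et al. [2] penalizes the squared distance between each deep feature x_i and its class center c_{y_i}, and is combined with the softmax/CE loss L_S through a weight λ:

```latex
L_C = \frac{1}{2}\sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2,
\qquad
L = L_S + \lambda L_C
```

Here m is the mini-batch size, and the centers c_j are updated along with the network parameters during training.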

A PyTorch implementation is available at https://github.com/KaiyangZhou/pytorch-center-loss; the repository also shows feature visualizations on MNIST with and without center loss.
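Independent of that repository, a minimal center-loss module might look like the sketch below; the class and parameter names (`CenterLoss`, `feat_dim`, `alpha`) are illustrative, not necessarily the repo's exact API:

```python
import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    """Minimal center loss: one learnable center per class, penalizing the
    squared distance between each feature and the center of its class."""
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels):
        batch_centers = self.centers[labels]                       # (B, feat_dim)
        return 0.5 * ((features - batch_centers) ** 2).sum(dim=1).mean()

# Joint objective: CE on the logits + weighted center loss on the deep features.
ce = nn.CrossEntropyLoss()
center_loss = CenterLoss(num_classes=10, feat_dim=2)
alpha = 0.003   # assumed weight for the center term; tune per task

features = torch.randn(32, 2, requires_grad=True)   # dummy penultimate-layer features
logits = torch.randn(32, 10, requires_grad=True)    # dummy classifier outputs
labels = torch.randint(0, 10, (32,))

loss = ce(logits, labels) + alpha * center_loss(features, labels)
loss.backward()   # the centers receive gradients too, so add them to the optimizer
```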

References

[1] CC-Loss: Channel Correlation Loss for Image Classification
[2] Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016
[3] https://github.com/KaiyangZhou/pytorch-center-loss
