深入理解: 为什么MSE Loss不适合处理分类任务?

任务场景

假设当前任务为猫狗二分类任务,猫的label为1, one-hot编码为[0, 1],狗的label是0,one-hot编码为[1, 0];

假设选取模型的最后输出维度为(N, 2), 其中N为Batch size,2为num_classes。

为什么回归任务的MSE Loss不适合处理分类任务?

如果我们选择MSE Loss作为猫狗二分类任务的损失函数,比如某个样本类别为猫,label为[0, 1], 模型的输出为[0.48, 0.52]。

那么MSE Loss所做的就是引导模型在处理这个样本时,模型输出的第一个值越接近0越好,模型输出的第二个值越接近1越好;但我们有必要让模型的输出精确到0/1吗?

分析

  1. 分类任务常用的评估指标是准确率,并不是回归任务常用的RMSE、MSE等指标;因此,对于分类任务,我们更在乎的是分类的准确率==>我们希望损失函数奖励正确分类,惩罚错误分类,而正确分类不一定非得让模型的输出精确到0/1
  2. 我们实际上是利用argmax函数从模型的输出得到分类结果,argmax函数的功能简单来说就是返回序列最大值对应的索引,例如argmax([0.48, 0.52])的返回值为1, 也就是模型预测这个样本是猫==>实际上只要模型的输出的第二个值>第一个值,就已经分类正确了。既然分类正确,对于这个样本的损失应该已经趋于平稳/收敛才比较合理,显然MSE Loss远远没有收敛(因为输出[0.48, 0.52]对于目标[0,1]还差得很远)
  3. 综合来看,MSE Loss并不适合处理分类任务。那么有没有一种Loss只要输出序列最大值的索引和标签能对上,就奖励(损失很低),否则就惩罚(损失很高)呢?——负对数似然(Negative Log Likelihood, NLL)损失函数

负对数似然(NLL)损失函数

公式如下:
深入理解: 为什么MSE Loss不适合处理分类任务?_第1张图片

其中,N为样本个数, o u t p u t i output_{i} outputi表示第i个样本的输出(经过softmax函数,输出概率之和为1), l a b e l i label_{i} labeli表示第i个样本的标签(对于猫狗二分类,标签为0/1), 那么 o u t p u t i [ l a b e l i ] output_{i}[label_{i}] outputi[labeli]表示的便是目标类别的输出概率。

NLL Loss函数以目标类别的预测概率作为输入,其曲线如下所示:
深入理解: 为什么MSE Loss不适合处理分类任务?_第2张图片

从上图可以看出:

  • 当目标类别的概率较低时,NLL非常高,且曲线下降速度非常快;
  • 当目标类别的预测概率大于0.5时(由于输出的概率之和为1,此时argmax得到的预测类别一定是目标类别), NLL较低,且已经趋于平缓/收敛;

因此,相比于MSE Loss, NLL Loss更适用于处理分类问题,而分类任务常用的交叉熵损失正是基于NLL Loss!!!(基于NLL Loss的交叉熵函数的pytorch保姆级复现见下期博客)

你可能感兴趣的:(深度学习,PyTorch,分类,机器学习,人工智能)