通俗理解简单的交叉熵损失函数

说起交叉熵损失函数「Cross Entropy Loss」,我们都不陌生,脑海中会马上浮现出它的公式:


我们已经对这个交叉熵函数的形式非常熟悉,多数情况下都是直接拿来使用。那么,它是怎么来的?为什么它能表征真实样本标签和预测概率之间的差值?上面的交叉熵函数是否有其它变种?接下来我将尽可能通俗地回答上面这几个问题。


(一)交叉熵损失函数的数学原理


我们知道,在二分类问题模型,例如逻辑回Logistic Regressio、神经网络等,真实样本的标签为 [0,1],分别表示负类、正类。模型的最后通常会经过一个 Sigmoid 函数,输出一个概率值,这个概率值反映了预测为正类的可能性:概率越大,可能性越大。

Sigmoid 函数的表达式和图形如下所示:

通俗理解简单的交叉熵损失函数_第1张图片


其中 s 是模型上一层的输出,Sigmoid 函数有这样的特点:s = 0 时,g(s) = 0.5;s >> 0 时, g ≈ 1,s << 0 时,g ≈ 0。显然,g(s) 将前一级的线性输出映射到 [0,1] 之间连续的数值概率上。这里的 g(s) 就是交叉熵公式中的模型预测输出 。

我们说了,预测输出,即Sigmoid函数的输出表征了当前样本标签为1的概率:

很明显,当前样本标签为 0 的概率就可以表达成:

如果我们从极大似然性的角度出发,把上面两种情况整合到一起:

对于这个式子而言,

当真实样本标签 y = 0 时,上面式子第一项就为 1,概率等式转化为:

当真实样本标签 y = 1 时,上面式子第二项就为 1,概率等式转化为:

两种情况下概率表达式与之前的完全一致,只不过我们把两种情况整合在一起了。

重点看一下整合之后的概率表达式,我们希望的是概率 P(y|x) 越大越好。首先,我们对 P(y|x) 引入log函数,因为log运算并不会影响函数本身的单调性。则有:



我们希望log P(y|x)越大越好,反过来,只要 -log P(y|x) 越小就行了。那我们就可以引入损失函数,且令 Loss = -log P(y|x)即可。则得到损失函数为:

非常简单,我们已经推导出了单个样本的损失函数,是如果是计算 N 个样本总的损失函数,只要将N个Loss叠加起来就可以了:

这样,我们已经完整地实现了交叉熵损失函数的推导过程。


(二)交叉熵损失函数的直观理解


我们现在已经知道了交叉熵损失函数的推导过程,能不能从更直观的角度去理解这个表达式,而不是仅仅记住这个公式。接下来,我们从图形的角度,分析交叉熵函数,以此来加深大家的理解。

首先,写出单个样本的交叉熵损失函数:


我们知道,当 y = 1 时:



这时候,L 与预测输出的关系如下图所示:


通俗理解简单的交叉熵损失函数_第2张图片


看了 L 的图形,简单明了!横坐标是预测输出,纵坐标是交叉熵损失函数 L。显然,预测输出越接近真实样本标签 1,损失函数 L 越小;预测输出越接近 0,L 越大。因此,函数的变化趋势完全符合实际需要的情况。


当 y = 0 时:

这时候,L 与预测输出的关系如下图所示:


通俗理解简单的交叉熵损失函数_第3张图片


同样,预测输出y^越接近真实样本标签0,损失函数L越小;预测函数越接近1,L越大。函数的变化趋势也完全符合实际需要的情况。

上面两图,可以帮助我们对交叉熵损失函数有直观的理解。无论真实样本标签 y是0还是1,L都表征了预测输出y^与y的差距。


另外,重点提一点的是,从图形中我们可以发现:预测输出y^与y差得越多,L的值越大,也就是说对当前模型的“ 惩罚”越大,而且是非线性增大,是一种类似指数增长的级别。这是由log函数本身的特性所决定的。这样的好处是模型会倾向于让预测输出更接近真实样本标签y。


(三)交叉熵损失函数的其它形式

交叉熵损失函数还有其它形式,之前介绍的是一个典型的形式。接下来将从另一个角度推导新的交叉熵损失函数。

这种形式下假设真实样本的标签为 +1 和 -1,分别表示正类和负类。有个已知的知识点是Sigmoid 函数具有如下性质:



之前提及 y = +1 时,下列等式成立:


如果 y = -1 时,并引入 Sigmoid 函数的性质,下列等式成立:


重点来了,因为 y 取值为 +1 或 -1,可以把 y 值带入,将上面两个式子整合到一起:


这个比较好理解,分别令 y = +1 和 y = -1 就能得到上面两个式子。

接下来,同样引入 log 函数(极大似然),得到:


要让概率最大,反过来,只要其负数最小即可。那么就可以定义相应的损失函数为:


Sigmoid函数的表达式g(ys) 带入:


L 就是我们要推导的交叉熵损失函数。如果是 N 个样本,其交叉熵损失函数为:


接下来,我们从图形化直观角度来看。当 y = +1 时:

这时候,L 与上一层得分函数 s 的关系如下图所示:


通俗理解简单的交叉熵损失函数_第4张图片


横坐标是 s,纵坐标是 L。显然,s 越接近真实样本标签 1,损失函数 L 越小;s 越接近 -1,L 越大。

另一方面,当 y = -1 时:

这时候,L 与上一层得分函数 s 的关系如下图所示:


通俗理解简单的交叉熵损失函数_第5张图片


同样,s 越接近真实样本标签 -1,损失函数 L 越小;s 越接近 +1,L 越大。


(四)总结

本文主要介绍了交叉熵损失函数的数学原理和推导过程,也从不同角度介绍了交叉熵损失函数的两种形式。第一种形式在实际应用中更加常见,例如神经网络等复杂模型;第二种多用于简单的逻辑回归模型。




你可能感兴趣的:(特征工程,机器学习,机器学习与人工智能,LR,损失函数,交叉熵)