在处理机器学习或深度学习问题时,损失/成本函数用于在训练期间优化模型。目标几乎总是最小化损失函数。损失越低,模型越好。交叉熵损失是最重要的成本函数。它用于优化分类模型。对交叉熵的理解取决于对 Softmax 激活函数的理解。我在下面写了另一篇文章来涵盖这个先决条件
考虑一个 4
类分类任务,其中图像被分类为狗、猫、马或猎豹。
上图中,Softmax 将 logits 转换为概率。交叉熵的目的是获取输出概率(P)并测量与真值的距离(如下图所示)。
对于上面的示例,类 dog
的所需输出是 [1,0,0,0]
,但模型输出 [0.775, 0.116, 0.039, 0.070]
。
目标是使模型输出尽可能接近期望的输出(真值)。在模型训练过程中,模型权重会相应迭代调整,目的是最小化交叉熵损失。调整权重的过程定义了模型训练,随着模型不断训练并且损失最小化,我们说模型正在学习。
交叉熵的概念可以追溯到信息论领域,克劳德·香农 (Claude Shannon) 在 1948 年引入了熵的概念。在深入研究交叉熵成本函数之前,让我们先介绍一下熵。
随机变量 X 的熵是变量可能结果固有的不确定性水平。
对于 p(x)
— 概率分布和随机变量 X,熵定义如下:
方程 1:熵的定义。注意 log 以 2 为底计算。
负号的原因: log(p(x))<0
对于 (0,1)
中的所有 p(x)
。 p(x) 是概率分布,因此值必须介于 0 和 1 之间。
log(x) 的绘图。对于介于 0 和 1 之间的 x 值,log(x) <0(负数)。
熵 H(x)
的值越大,概率分布的不确定性越大,值越小,不确定性越小。
考虑以下 3 个具有形状的“容器”:三角形和圆形
容器1:选择三角形的概率是26/30,选择圆形的概率是4/30。因此,选择一种形状和/或不选择另一种形状的概率更加确定。
容器 2:选择三角形的概率为 14/30,否则为 16/30。几乎有 50-50 的机会选择任何特定形状。选择给定形状的确定性低于 1。
容器 3:从容器 3 中选取的形状很可能是圆形。选取圆形的概率为 29/30,选取三角形的概率为 1/30。非常确定所选择的形状将是圆形。
让我们计算熵,以便我们确定我们对选择给定形状的确定性的断言。
正如预期的那样,第一个和第三个容器的熵小于第二个容器。这是因为在容器 1 和 3 中选择给定形状的概率比在容器 2 中更加确定。现在我们可以继续讨论交叉熵损失函数。
也称为对数损失、对数损失或逻辑损失。将每个预测类别概率与实际类别所需输出 0 或 1 进行比较,并计算分数/损失,根据概率与实际预期值的差距对概率进行惩罚。惩罚本质上是对数的,对于接近 1 的大差异产生大分数,对于趋向于 0 的小差异产生小分数。
在训练期间调整模型权重时使用交叉熵损失。目的是最小化损失,即损失越小模型越好。完美模型的交叉熵损失为 0。
交叉熵定义为:
对于二元分类(具有两个类别 - 0 和 1 的分类任务),我们将二元交叉熵定义为
二元交叉熵通常计算为所有数据示例的平均交叉熵,即
考虑具有以下 Softmax 概率 (S) 和标签 (T) 的分类问题。目标是根据这些信息计算交叉熵损失。
Logits(S) 和单热编码真值标签(T) 以及分类交叉熵损失函数用于测量预测概率和真值标签之间的“距离”。
分类交叉熵计算如下
Softmax 是连续可微函数。这使得计算损失函数对于神经网络中每个权重的导数成为可能。该属性允许模型相应地调整权重以最小化损失函数(模型输出接近真实值)。
假设经过一些模型训练迭代后,模型输出以下 logits 向量
0.095
小于之前的损失,即 0.3677
暗示模型正在学习。优化过程(调整权重以使输出接近真实值)持续到训练结束。
一般深度学习框架提供了以下交叉熵损失函数:二元、分类、稀疏分类交叉熵损失函数。
分类交叉熵和稀疏分类交叉熵都具有与等式 2 中定义的相同的损失函数。两者之间的唯一区别在于如何定义真值标签。
[1,0,0]
、 [0,1,0]
和 [0,0,1].
[1]
、 [2]
和 [3]
。@article{Koech2023Aug,
author = {Koech, Kiprono Elijah},
title = {{Cross-Entropy Loss Function - Towards Data Science}},
journal = {Medium},
year = {2023},
month = aug,
urldate = {2023-11-29},
publisher = {Towards Data Science},
language = {english},
url = {https://towardsdatascience.com/cross-entropy-loss-function-f38c4ec8643e}
}