分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系

1. 引言

我们都知道损失函数有很多种:均方误差(MSE)、SVM的合页损失(hinge loss)、交叉熵(cross entropy)。这几天看论文的时候产生了疑问:为啥损失函数很多用的都是交叉熵(cross entropy)?其背后深层的含义是什么?如果换做均方误差(MSE)会怎么样?下面我们一步步来揭开交叉熵的神秘面纱。

2. 交叉熵的来源

2.1 信息量

一条信息的信息量大小和它的不确定性有很大的关系。一句话如果需要很多外部信息才能确定,我们就称这句话的信息量比较大。比如你听到“云南西双版纳下雪了”,那你需要去看天气预报、问当地人等等查证(因为云南西双版纳从没下过雪)。相反,如果和你说“人一天要吃三顿饭”,那这条信息的信息量就很小,因为条信息的确定性很高。

那我们就能将事件x_0的信息量定义如下(其中p(x_0)表示事件x_0发生的概率):

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第1张图片

概率总是一个0-1之间的值,-log(x)的图像如上

2.2 熵

信息量是对于单个事件来说的,但是实际情况一件事有很多种发生的可能,比如掷骰子有可能出现6种情况,明天的天气可能晴、多云或者下雨等等。熵是表示随机变量不确定的度量,是对所有可能发生的事件产生的信息量的期望。公式如下:

n表示事件可能发生的情况总数

其中一种比较特殊的情况就是掷硬币,只有正、反两种情况,该种情况(二项分布或者0-1分布)熵的计算可以简化如下:

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第2张图片

p(x)代表掷正面的概率,1-p(x)则表示掷反面的概率(反之亦然)

2.3 相对熵

相对熵又称KL散度,用于衡量对于同一个随机变量x的两个分布p(x)和q(x)之间的差异。在机器学习中,p(x)常用于描述样本的真实分布,例如[1,0,0,0]表示样本属于第一类,而q(x)则常常用于表示预测的分布,例如[0.7,0.1,0.1,0.1]。显然使用q(x)来描述样本不如p(x)准确,q(x)需要不断地学习来拟合准确的分布p(x)。

KL散度的公式如下:

n表示事件可能发生的情况总数

KL散度的值越小表示两个分布越接近。

2.4 交叉熵

我们将KL散度的公式进行变形,得到:

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第3张图片

前半部分就是p(x)的熵,后半部分就是我们的交叉熵:

机器学习中,我们常常使用KL散度来评估predict和label之间的差别,但是由于KL散度的前半部分是一个常量,所以我们常常将后半部分的交叉熵作为损失函数,其实二者是一样的。

对二分类情况下,有:

这里yi就是label,\hat{y_i}是logits经过softmax归一化到(0,1)之间的结果。在yi取0或1的情况下,信息熵部分为0,所以KL散度就等于交叉熵,但是在一些情况下,例如使用标签平滑处理技术后,yi的取值不是0或1,这时候,KL散度相当于在交叉熵的基础上减去了一个常数,KL散度作为损失函数去优化模型的效果和交叉熵是完全一样的,但是在这种情况下当模型完美拟合标签的情况下KL散度的最小值可取到0,而此时交叉熵能够取到的最小值是信息熵不为0,所以这种情况下使用KL散度更符合我们对Loss的一般认识。用pytorch实现可以看到KLDivLoss和CELoss是相等的:

import torch.nn.functional as F
import torch
import torch.nn as nn
# nn.CrossEntropyLoss() 和  KLDivLoss 关系

y_pred = torch.tensor([[10.0, 0.0, -10.0], [8.0, 8.0, 8.0]])
y_true = torch.tensor([0, 2])
ce = nn.CrossEntropyLoss(reduction="none")(y_pred, y_true)
print(ce)
'''
输出shape是2,tensor([4.5418e-05, 1.0986e+00])
'''

# NLLLoss要求target只能是第几类下标,例如[0,2]表示[label0,label2],转成onehot就是[[1,0,0],[0,0,1]]
nll_log_softmax = nn.NLLLoss(reduction="none")(F.log_softmax(y_pred, dim=-1), y_true)
print(nll_log_softmax)
'''
输出shape是2,tensor([4.5418e-05, 1.0986e+00])
'''

one_hot = F.one_hot(y_true) #将第几类的下标转换成onehot形式,例如输入[0,2]表示[label0,label2],输出onehot就是[[1,0,0],[0,0,1]]
'''
# KLDivLoss要求target为float形式编码,one_hot是longtensor,
  所以要one_hot.float();如果是普通的logics,要过一下softmax

# KLDivLoss也要求Logits经过LogSoftmax激活。LogSoftmax会把(-inf,inf)的Logits映射到(0,1)再映射到(-inf,0):
  当用NLLLoss时,刚好多个负号loss变成(0,inf);当用KLDivLoss时,刚好多个熵。

回顾klLoss的公式 p_i*log(p_i/q_i),其中p_i是(0,1)范围内的targets
q_i是将logits映射到(0,1)范围内的结果,所以p_i和q_i都是(0,1)之间
KLDivLoss这个函数的特点就是把log(q_i)这一步扔给输入自己算,这个函数管的只是p_i*log(p_i)-p_i*input
  NLLLoss这个函数的特点就是把p_i*log(p_i)也没了,只有-p_i*input,所以和LogSoftmax组合起来是CE
'''

kl = nn.KLDivLoss(reduction="none")(F.log_softmax(y_pred, dim=-1), one_hot.float())
print(kl) #输出shape是2*3
'''
tensor([[4.5418e-05, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 1.0986e+00]])
'''

a = F.softmax(torch.randn(2,3))
print(nn.KLDivLoss(reduction="none")(torch.log(a), a))
'''
输出是
tensor([[0., 0., 0.],
        [0., 0., 0.]])

回顾klLoss的公式 p_i*log(p_i/q_i),其中p_i是(0,1)范围内的targets
q_i是将logits映射到(0,1)范围内的结果,所以p_i和q_i都是(0,1)之间
KLDivLoss这个函数的特点就是把log(q_i)这一步扔给输入自己算,这个函数管的只是p_i*log(p_i)-p_i*input
  NLLLoss这个函数的特点就是把p_i*log(p_i)也没了,只有-p_i*input,所以和LogSoftmax组合起来是CE
'''

3. 交叉熵作为loss函数的直觉

在回归问题中,我们常常使用均方误差(MSE)作为损失函数,其公式如下:

m表示样本个数,loss表示的是m个样本的均值

其实这里也比较好理解,因为回归问题要求拟合实际的值,通过MSE衡量预测值和实际值之间的误差,可以通过梯度下降的方法来优化。而不像分类问题,需要一系列的激活函数(sigmoid、softmax)来将预测值映射到0-1之间,这时候再使用MSE的时候就要好好掂量一下了,为啥这么说,请继续看:

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第4张图片

sigmoid加MES的基本公式

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第5张图片

gradient推导过程

上面复杂的推导过程,其实结论就是下面一张图:

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第6张图片

C就是?的J,sigma就是sigmoid函数,a就是predict

从以上公式可以看出,w和b的梯度跟激活函数的梯度成正比,激活函数的梯度越大,w和b的大小调整得越快,训练收敛得就越快。而我们都知道sigmoid函数长这样:

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第7张图片

图片来自:https://blog.csdn.net/u014313009/article/details/51043064

在上图的绿色部分,初始值是0.98,红色部分初始值是0.82,假如真实值是0。直观来看那么0.82下降的速度明显高于0.98,但是明明0.98的误差更大,这就导致了神经网络不能像人一样,误差越大,学习的越快。

但是如果我们把MSE换成交叉熵会怎么样呢?

x表示样本,n表示样本的总数

重新计算梯度:

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第8张图片

推导过程

另外sigmoid有一个很好的性质:

我们从结果可以看出梯度中不再含有sigmoid的导数,有的是sigmoid的值和实际值之间的差,也就满足了我们之前所说的错误越大,下降的越快。

这也就是在分类问题中常用cross entropy 而不是 MSE的原因了:原因一,使用交叉熵loss下降的更快;原因二,使用交叉熵是凸优化,MSE是非凸优化

我们从最简单的线性回归开始讨论:
线性回归(回归问题)使用的是平方损失:

因为这个函数  是凸函数,直接求导等于零,即可求出解析解,很简单。但是对于逻辑回归则不行(分类问题)【注意:逻辑回归不是回归!是分类!!】因为如果逻辑回归也用平方损失作为损失函数,则:

其中  表示样本数量。
上式是非凸的,不能直接求解析解,而且不宜优化,易陷入局部最优解,即使使用梯度下降也很难得到全局最优解。如下图所示:

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第9张图片

下一个问题,回归问题能用交叉熵吗 怎么用?:

回归问题常用mse作为损失函数,这里面一个隐含的预设是数据误差符合高斯分布。交叉熵则是以数据分布服从多项式分布为前提。因此本质上回归应该用什么样的损失函数取决于数据分布。损失函数的选择本身也是一种先验偏好,选择mse意味着你认为数据误差符合高斯分布,选择交叉熵则表示你倾向于认为数据接近多项式分布。如果你的先验直觉比较准确,符合实际情况,那模型效果应该会更好一些。 多项式分布一般和离散数据相关,但如果连续数据分桶后接近多项式分布,那选用mse可能就不合时宜了。那么如何使用交叉熵损失建模回归问题呢? 首先回顾下交叉熵损失函数,以二分类问题为例:

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第10张图片

分类问题中为什么用交叉熵而不用MSE KL散度和交叉熵的关系_第11张图片

转载自 简单的交叉熵,你真的懂了吗? - 知乎
交叉熵损失(Cross-entropy)和平方损失(MSE)究竟有何区别?

分类必然交叉熵,回归无脑MSE?未必 - 知乎

你可能感兴趣的:(算法,概率论,机器学习,人工智能)