本文主要是讲述Softmax和CrossEntropy的公式推导,并用代码进一步佐证。
我们把 S o f t m a x Softmax Softmax输出的概率定义为 p i p_i pi:
S o f t m a x ( a i ) = p i = e a i ∑ j N e a j Softmax(a_i) = p_i = \frac {e^{a_i}} {\sum_j^N e^{a_j}} Softmax(ai)=pi=∑jNeajeai
模型输出 [ a 1 , a 2 , . . . , a N ] [a_1, a_2, ..., a_N] [a1,a2,...,aN],共N个值。
其中 a i a_i ai代表第 i i i个输出值, p i p_i pi代表第 i i i个输出值经过 S o f t m a x Softmax Softmax计算过后的概率。
且 p 1 + p 2 + . . . + p N = 1 p_1+p_2+...+p_N=1 p1+p2+...+pN=1
因为Softmax涉及到指数函数,且底数 e e e 大于1,在计算机中是可能会有溢出风险的。结合指数、对数函数的转换规则,我们可以制定一些数值稳定的优化策略。(当然这些是在框架中实现的,学习更多是为了扩展视野)
数值稳定的主要思路在于 a i a_i ai减去 A = [ a 1 , a 2 , . . . , a N ] A=[a_1, a_2, ..., a_N] A=[a1,a2,...,aN]中的最大值 m a x ( A ) max(A) max(A)
p i = e a i ∑ j N e a j = C ⋅ e a i C ⋅ ∑ j N e a j = e log ( C ) ⋅ e a i e log ( C ) ⋅ ∑ j N e a j = e a i + log ( C ) ∑ j N e a j + log ( C ) = e a i − m a x ( A ) ∑ j N e a j − m a x ( A ) \begin{aligned} p_i & = \frac {e^{a_i}} {\sum_j^N e^{a_j}}\\ & = \frac {C \cdot e^{a_i}} {C \cdot \sum_j^N e^{a_j}}\\ & = \frac {e^{\log(C)} \cdot e^{a_i}} {e^{\log(C)} \cdot \sum_j^N e^{a_j}}\\ & = \frac {e^{a_i + \log(C)}} {\sum_j^N e^{a_j + \log(C)}}\\ & = \frac {e^{a_i - max(A)}} {\sum_j^N e^{a_j - max(A)}}\\ \end{aligned} pi=∑jNeajeai=C⋅∑jNeajC⋅eai=elog(C)⋅∑jNeajelog(C)⋅eai=∑jNeaj+log(C)eai+log(C)=∑jNeaj−max(A)eai−max(A)
因为C是常数, l o g ( C ) log(C) log(C)也是常数,所以我们可以把 l o g ( C ) log(C) log(C)定义为-max(A),且并不会改变 p i p_i pi的计算结果。
减去A中的最大值,就能确保A中所有的值都不会上溢出。
我们把交叉熵公式定义为 H H H:
C r o s s E n t r o p y ( y i , p i ) = H ( y i , p i ) = − ∑ i N y i ⋅ log ( p i ) CrossEntropy(y_i, p_i) = H(y_i, p_i) = -\sum_i^Ny_i \cdot \log (p_i) CrossEntropy(yi,pi)=H(yi,pi)=−i∑Nyi⋅log(pi)
在多分类问题中,我们的 label 通常以独热码(one-hot)的形式展现和学习,因此在 Y = [ y 1 , y 2 , . . . , y N ] Y=[y_1, y_2, ..., y_N] Y=[y1,y2,...,yN] 中,只有一项为 1 1 1,其余项为 0 0 0,即 [ 0 , 0 , . . . , 1 , . . . , 0 , 0 ] [0, 0, ..., 1, ..., 0, 0] [0,0,...,1,...,0,0]。
所以 H ( y i , p i ) H(y_i, p_i) H(yi,pi) 也等于 − y i ⋅ log ( p i ) -y_i \cdot \log(p_i) −yi⋅log(pi), y i = 1 y_i=1 yi=1 对应label的类别。
据 S o f t m a x Softmax Softmax 公式可知,每个 p i p_i pi 均是所有 a a a 都有参与(分母)运算的,因此梯度的形式为:
∂ p i ∂ a j = ∂ ( e a i ∑ j N e a j ) ∂ a j \frac {\partial p_i} {\partial a_j} = \frac{\partial (\frac {e^{a_i}}{\sum_j^N e^{a_j}})}{\partial a_j} ∂aj∂pi=∂aj∂(∑jNeajeai)
而且 i i i 和 j j j 的关系要分类讨论。
这里要先复习下含分母的求导公式:
( h ( x ) g ( x ) ) ′ = h ′ ( x ) ⋅ g ( x ) − h ( x ) ⋅ g ′ ( x ) g ( x ) 2 (\frac{h(x)}{g(x)})^\prime = \frac{h'(x)\cdot g(x)-h(x)\cdot g'(x)}{g(x)^2} (g(x)h(x))′=g(x)2h′(x)⋅g(x)−h(x)⋅g′(x)
并且简化一下符号:
∑ j N e a j = ∑ \sum_j^Ne^{a_j} = \sum j∑Neaj=∑
当 i = j i=j i=j:
∂ p i ∂ a j = e a i ⋅ ∑ − e a i ⋅ e a j ∑ ⋅ ∑ = e a i ⋅ ( ∑ − e a j ) ∑ ⋅ ∑ = p i ⋅ ( 1 − p j ) \begin{aligned} \frac {\partial p_i} {\partial a_j} & = \frac{e^{a_i} \cdot \sum - e^{a_i} \cdot e^{a_j}}{\sum \cdot \sum} \\ & = \frac{e^{a_i} \cdot (\sum - e^{a_j})}{\sum \cdot \sum}\\ & = p_i \cdot (1 - p_j)\\ \end{aligned} ∂aj∂pi=∑⋅∑eai⋅∑−eai⋅eaj=∑⋅∑eai⋅(∑−eaj)=pi⋅(1−pj)
当 i ≠ j i \neq j i=j( e a i e^{a_i} eai相当于常数,导数为 0):
∂ p i ∂ a j = 0 ⋅ ∑ − e a i ⋅ e a j ∑ ⋅ ∑ = − p i ⋅ p j \begin{aligned} \frac {\partial p_i} {\partial a_j} & = \frac{0 \cdot \sum - e^{a_i} \cdot e^{a_j}}{\sum \cdot \sum} \\ & = - p_i \cdot p_j \end{aligned} ∂aj∂pi=∑⋅∑0⋅∑−eai⋅eaj=−pi⋅pj
∂ H ∂ a j = ∂ H ∂ p i ⋅ ∂ p i ∂ a j = − ∑ i y i 1 p i ⋅ ∂ p i ∂ a j − − − ① \frac {\partial H}{\partial a_j} = \frac {\partial H}{\partial p_i} \cdot \frac {\partial p_i}{\partial a_j} = -\sum_iy_i\frac{1}{p_i} \cdot \frac {\partial p_i}{\partial a_j}---① ∂aj∂H=∂pi∂H⋅∂aj∂pi=−i∑yipi1⋅∂aj∂pi−−−①
当 i = j i=j i=j:
① = − ∑ i = j y i 1 p i ⋅ p i ⋅ ( 1 − p j ) = − ∑ i = j y i ⋅ ( 1 − p j ) = − y i + y i p j ( 因 为 只 有 i , 可 以 把 ∑ 去 掉 ) − − − ② \begin{aligned} ① & = -\sum_{i=j}y_i\frac{1}{p_i} \cdot p_i \cdot (1 - p_j) \\ & = -\sum_{i=j}y_i\cdot (1 - p_j) \\ & = -y_i + y_ip_j (因为只有i,可以把\sum去掉)---② \end{aligned} ①=−i=j∑yipi1⋅pi⋅(1−pj)=−i=j∑yi⋅(1−pj)=−yi+yipj(因为只有i,可以把∑去掉)−−−②
当 i ≠ j i \neq j i=j:
① = − ∑ i ≠ j y i 1 p i ⋅ ( − p i ⋅ p j ) = ∑ i ≠ j y i p j − − − ③ \begin{aligned} ① & = -\sum_{i \neq j}y_i\frac{1}{p_i} \cdot (-p_i \cdot p_j) \\ & = \sum_{i \neq j}y_i p_j --- ③ \end{aligned} ①=−i=j∑yipi1⋅(−pi⋅pj)=i=j∑yipj−−−③
因为②和③其实是①的互斥情况,所以可以合并:
① = ② + ③ ( 记 住 在 ② 中 , i = j ) = − y i + y i p j + ∑ i ≠ j y i p j = − y i + ( ∑ i = j y i p j + ∑ i ≠ j y i p j ) = − y i + ∑ i N y i p j ( 因 为 y i 是 o n e − h o t , ∑ i N y i = 1 ) = p j − y j ( y i = y j ) \begin{aligned} ① & = ②+③(记住在②中,i=j) \\ & = -y_i + y_ip_j + \sum_{i \neq j}y_i p_j \\ & = -y_i + (\sum_{i=j}y_ip_j + \sum_{i \neq j}y_i p_j) \\ & = -y_i + \sum_i^N y_ip_j(因为y_i是one-hot,\sum_i^N y_i=1) \\ & = p_j - y_j(y_i=y_j) \end{aligned} ①=②+③(记住在②中,i=j)=−yi+yipj+i=j∑yipj=−yi+(i=j∑yipj+i=j∑yipj)=−yi+i∑Nyipj(因为yi是one−hot,i∑Nyi=1)=pj−yj(yi=yj)
整个Softmax+CrossEntropy的求导推导下来发现, H H H对于 a j a_j aj的偏导值,就是让他的 p j p_j pj去减对应的label值( y j y_j yj)。
举例P = [0.5, 0.3, 0.2],Y=[1, 0, 0],对应的导数就是 [-0.5, 0.3, 0.2]。
先看下x在softmax+cross entorpy前向计算并且BP后,所产生的梯度是多少,即 ∂ H ∂ a j \frac{\partial H}{\partial a_j} ∂aj∂H,在这个例子中分别对a1,a2,a3求偏导:
x = torch.randn((1, 3), requires_grad=True)
# tensor([[-0.3876, 0.2697, -1.6527]], requires_grad=True)
y = torch.randint(3, (1,), dtype=torch.int64)
# tensor([1])
loss = F.cross_entropy(x, y)
# F.cross_entropy含了softmax+cross_entropy
# 因此直接调用即可,无需先使用F.softmax
print(loss)
# tensor(0.5095, grad_fn=)
loss.backward()
print(x.grad)
# tensor([[ 0.3113, -0.3992, 0.0879]])
下面再看下 p i p_i pi 的值:
F.softmax(x, dim=1)
# tensor([[0.3113, 0.6008, 0.0879]], grad_fn=)
发现没有!发现没有!除了 a 1 a_1 a1 比 p 1 p_1 p1 减了1之外,其他都没变!正正验证了上面的公式推导!
总结一下,在多分类问题中,softmax+cross entropy是比较普遍,且计算速度较快的损失函数(loss function),因为它的梯度仅仅只用把概率值(pi)减去标签(yi)即可!
这在训练的初期,可以提供较快的训练速度,以提供后续优化的方向。当然,后续也包括对损失函数的优化!