Softmax计算技巧

初始做法

在softmax回归中,定义
y ^ = s o f t m a x ( o ) y ^ j = exp ⁡ ( o j ) ∑ k exp ⁡ ( o k ) ( 1 ) ( i = 1... n , k = 1... q ) \hat{\mathbf{y}} = \mathrm{softmax}(\mathbf{o})\quad \text\quad \hat{y}_j = \frac{\exp(o_j)}{\sum_k \exp(o_k)} \text\quad (1)\\ (i=1...n,k=1...q) y^=softmax(o)y^j=kexp(ok)exp(oj)(1)(i=1...n,k=1...q)
对于任何标签 y y y 和模型预测 y ^ \hat{y} y^ ,损失函数为:
l ( y , y ^ ) = − ∑ j = 1 q y j log ⁡ y ^ j ( 2 ) l(\mathbf{y}, \hat{\mathbf{y}}) = - \sum_{j=1}^q y_j \log \hat{y}_j \text\quad (2) l(y,y^)=j=1qyjlogy^j(2)
( 1 ) (1) (1) 代入 ( 2 ) (2) (2) 中:
l ( y , y ^ ) = − ∑ j = 1 q y j log ⁡ exp ⁡ ( o j ) ∑ k = 1 q exp ⁡ ( o k ) = ∑ j = 1 q y j log ⁡ ∑ k = 1 q exp ⁡ ( o k ) − ∑ j = 1 q y j o j = log ⁡ ∑ k = 1 q exp ⁡ ( o k ) − ∑ j = 1 q y j o j . ( 3 ) \begin{split}\begin{aligned} l(\mathbf{y}, \hat{\mathbf{y}}) &= - \sum_{j=1}^q y_j \log \frac{\exp(o_j)}{\sum_{k=1}^q \exp(o_k)} \\ &= \sum_{j=1}^q y_j \log \sum_{k=1}^q \exp(o_k) - \sum_{j=1}^q y_j o_j\\ &= \log \sum_{k=1}^q \exp(o_k) - \sum_{j=1}^q y_j o_j. \end{aligned}\end{split} \text\quad (3) l(y,y^)=j=1qyjlogk=1qexp(ok)exp(oj)=j=1qyjlogk=1qexp(ok)j=1qyjoj=logk=1qexp(ok)j=1qyjoj.(3)
考虑相对于任何未规范化的预测 o j o_j oj 的导数,我们得到:
∂ o j l ( y , y ^ ) = exp ⁡ ( o j ) ∑ k = 1 q exp ⁡ ( o k ) − y j = s o f t m a x ( o ) j − y j ( 4 ) \partial_{o_j} l(\mathbf{y}, \hat{\mathbf{y}}) = \frac{\exp(o_j)}{\sum_{k=1}^q \exp(o_k)} - y_j = \mathrm{softmax}(\mathbf{o})_j - y_j \text\quad (4) ojl(y,y^)=k=1qexp(ok)exp(oj)yj=softmax(o)jyj(4)

问题1: exp ⁡ ( o k ) \exp(o_k) exp(ok) 可能特别大或特别小

softmax函数 y ^ j = exp ⁡ ( o j ) ∑ k exp ⁡ ( o k ) \hat y_j = \frac{\exp(o_j)}{\sum_k \exp(o_k)} y^j=kexp(ok)exp(oj), 其中 o j o_j oj 是预测 o \mathbf{o} o 的概率分布。 o j o_j oj 是未规范化的预测的第 j j j 个元素。
如果 o j o_j oj 中的一些数值非常大, 那么 exp ⁡ ( o k ) \exp(o_k) exp(ok) 可能大于数据类型容许的最大数字,即上溢(overflow)。 这将使分母或分子变为 inf ⁡ \inf inf(无穷大), + ∞ + ∞ \frac{+\infty}{+\infty} ++最后得到的是 0 0 0 inf ⁡ \inf inf 或 nan(不是数字)的 y j ^ \hat{y_j} yj^
另一方面 exp ⁡ ( o k ) \exp(o_k) exp(ok) 都特别小, ∑ k exp ⁡ ( o k ) \sum_k \exp(o_k) kexp(ok) 在实际计算中为 0 0 0,这样就出现了 0 / 0 0/0 0/0 的错误。
在这些情况下,我们无法得到一个明确定义的交叉熵值。

解决这个问题的一个技巧是: 在继续softmax计算之前,先从所有 o k o_k ok 中减去 max ⁡ ( o ) \max(\mathbf{o}) max(o)。 这里可以看到每个 o k o_k ok 按常数进行的移动不会改变softmax的返回值:
y ^ j = exp ⁡ ( o j − max ⁡ ( o ) ) exp ⁡ ( max ⁡ ( o ) ) ∑ k exp ⁡ ( o k − max ⁡ ( o ) ) exp ⁡ ( max ⁡ ( o ) ) = exp ⁡ ( o j − max ⁡ ( o ) ) ∑ k exp ⁡ ( o k − max ⁡ ( o ) ) . ( 5 ) \begin{split}\begin{aligned} \hat y_j & = \frac{\exp(o_j - \max(\mathbf{o}))\exp(\max(\mathbf{o}))}{\sum_k \exp(o_k - \max(\mathbf{o}))\exp(\max(\mathbf{o}))} \\ & = \frac{\exp(o_j - \max(\mathbf{o}))}{\sum_k \exp(o_k - \max(\mathbf{o}))}. \end{aligned}\end{split} \text\quad (5) y^j=kexp(okmax(o))exp(max(o))exp(ojmax(o))exp(max(o))=kexp(okmax(o))exp(ojmax(o)).(5)
这样分子 max ⁡ ( exp ⁡ ( o j − max ⁡ ( o ) ) ) = 1 \max(\exp(o_j - \max(\mathbf{o})))=1 max(exp(ojmax(o)))=1,分母 ∑ k exp ⁡ ( o k − max ⁡ ( o ) ) ≥ 1 \sum_k \exp(o_k - \max(\mathbf{o})) \ge 1 kexp(okmax(o))1

问题2: log ⁡ ( y ^ j ) = − ∞ \log{(\hat{y}_j)}=-\infty log(y^j)=

在正向、反向传播中都要计算公式 ( 2 ) (2) (2) 中的 log ⁡ ( y ^ j ) \log{(\hat{y}_j)} log(y^j),按照上面做归一化后,由于精度受限, exp ⁡ ( o j − max ⁡ ( o k ) ) \exp(o_j - \max(o_k)) exp(ojmax(ok)) 将有接近零的值,即下溢(underflow),此时 log ⁡ ( y ^ j ) = − ∞ \log{(\hat{y}_j)}=-\infty log(y^j)=
但是实际上,这个问题在实际数据运算时可以避免掉,在数学上,有下面的运算(永远可行!)
log ⁡ ( y ^ j ) = log ⁡ ( exp ⁡ ( o j − max ⁡ ( o ) ) ∑ k exp ⁡ ( o k − max ⁡ ( o ) ) ) = log ⁡ ( exp ⁡ ( o j − max ⁡ ( o ) ) ) − log ⁡ ( ∑ k exp ⁡ ( o k − max ⁡ ( o ) ) ) = o j − max ⁡ ( o ) − log ⁡ ( ∑ k exp ⁡ ( o k − max ⁡ ( o ) ) ) . ( 6 ) \begin{split}\begin{aligned} \log{(\hat y_j)} & = \log\left( \frac{\exp(o_j - \max(\mathbf{o}))}{\sum_k \exp(o_k - \max(\mathbf{o}))}\right) \\ & = \log{(\exp(o_j - \max(\mathbf{o})))}-\log{\left( \sum_k \exp(o_k - \max(\mathbf{o})) \right)} \\ & = o_j - \max(\mathbf{o}) -\log{\left( \sum_k \exp(o_k - \max(\mathbf{o})) \right)}. \end{aligned}\end{split} \text\quad (6) log(y^j)=log(kexp(okmax(o))exp(ojmax(o)))=log(exp(ojmax(o)))log(kexp(okmax(o)))=ojmax(o)log(kexp(okmax(o))).(6)
这样在计算 log ⁡ ( y ^ j ) \log{(\hat{y}_j)} log(y^j) 时不用先计算 y ^ j \hat{y}_j y^j 然后计算 log ⁡ ( y ^ j ) \log(\hat{y}_j) log(y^j),而是计算公式 ( 6 ) (6) (6) 的最后部分:
o j − max ⁡ ( o ) − log ⁡ ( ∑ k exp ⁡ ( o k − max ⁡ ( o ) ) ) o_j - \max(\mathbf{o}) -\log{\left( \sum_k \exp(o_k - \max(\mathbf{o})) \right)} ojmax(o)log(kexp(okmax(o)))
这一部分已经在机器运算时不会出现问题。

启发

由于数据结构的限制,数值计算在实际运算中要避免 log ⁡ ( 0 ) \log(0) log(0) 0 0 \frac{0}{0} 00 ∞ ∞ \frac{\infty}{\infty} 的情况,考虑在数学能不能进一步变换,找到等价的表达式是一个很好的思路,本文中的做法在其它问题中也可借鉴。

参考

3.7. softmax回归的简洁实现

你可能感兴趣的:(深度学习算法,人工智能,深度学习)