深度学习中的梯度消失、梯度爆炸问题的原因以及解决方法

本文简要介绍梯度消失 (gradient vanishing) 和梯度爆炸 (gradient exploding) 问题,并给出一些可行的解决方法。

文章目录

    • 1. 梯度推导过程
    • 2. sigmoid函数的性质
    • 3. 梯度消失与梯度爆炸的原因
    • 4. 一些其他的激活函数
      • 4.1 tanh函数
      • 4.2 ReLU函数
    • 5. 解决方案
      • 5.1 Batch Normalization 批标准化
      • 5.2 选用ReLU、leak ReLU、eReLU等激活函数
      • 5.3 使用梯度剪切预防梯度爆炸
      • 5.4 使用正则化 (regularization) 预防梯度爆炸
      • 5.5 使用残差网络ResNet
      • 5.6 LSTM网络

神经网络在更新参数的过程中,使用反向传播 (Backpropagation) 算法求得各层网络的梯度,可以看作是 神经网络式的链式法则。反向传播过程是导致梯度消失和梯度爆炸问题的主要原因,而且随着网络的深度增加,这些问题越发明显。

深度神经网络在解决复杂问题时,比浅层神经网络的效果更好,但也无法避免地会遇到梯度消失或梯度爆炸问题。

梯度消失的常见情形:

  1. 深层网络
  2. 使用了sigmoid激活函数

梯度爆炸的常见情形:

  1. 深层网络
  2. 参数的初始值过大

1. 梯度推导过程

我们用一个例子来说明:
在这里插入图片描述

上图是一个5层的网络,每一层的输入记为 x i x_i xi, bias记为 b i b_i bi,权重记为 w i w_i wi,输出为 z i z_i zi,激活后的输出记为 h i h_i hi,sigmoid函数用 σ \sigma σ表示,则有:
z i = x i w i + b i z_i = x_iw_i+b_i zi=xiwi+bi

h i = σ ( z i ) h_i = \sigma(z_i) hi=σ(zi)

x 4 = h 3 , x 3 = h 2 , x 2 = h 1 x_4 = h_3, \quad x_3 = h_2, \quad x_2 = h_1 x4=h3,x3=h2,x2=h1

设最终的损失函数为 L L L,则利用反向传播方法求 L L L关于 w 1 w_1 w1的梯度过程如下:
(1) ∂ L ∂ w 1 = ∂ L ∂ h 4 ∂ h 4 ∂ z 4 ∂ z 4 ∂ x 4 ∂ x 4 ∂ z 3 ∂ z 3 ∂ x 3 ∂ x 3 ∂ z 2 ∂ z 2 ∂ x 2 ∂ x 2 ∂ z 1 ∂ z 1 ∂ w 1 = ∂ L ∂ h 4 σ ′ ( z 4 ) w 4 σ ′ ( z 3 ) w 3 σ ′ ( z 2 ) w 2 σ ′ ( z 1 ) x 1 \begin{aligned} \frac{\partial L}{\partial w_1} &= \frac{\partial L}{\partial h_4} \frac{\partial h_4}{\partial z_4}\frac{\partial z_4}{\partial x_4}\frac{\partial x_4}{\partial z_3}\frac{\partial z_3}{\partial x_3}\frac{\partial x_3}{\partial z_2}\frac{\partial z_2}{\partial x_2}\frac{\partial x_2}{\partial z_1} \frac{\partial z_1}{\partial w_1} \\ &= \frac{\partial L}{\partial h_4}\sigma'(z_4)w_4\sigma'(z_3)w_3\sigma'(z_2)w_2\sigma'(z_1)x_1 \tag{1} \end{aligned} w1L=h4Lz4h4x4z4z3x4x3z3z2x3x2z2z1x2w1z1=h4Lσ(z4)w4σ(z3)w3σ(z2)w2σ(z1)x1(1)

2. sigmoid函数的性质

sigmoid函数的定义为:
σ = 1 1 + e − z \sigma = \frac{1}{1+e^{-z}} σ=1+ez1
导数:
σ ′ = e − z ( 1 + e − z ) 2 = σ ⋅ ( 1 − σ ) \sigma' = \frac{e^{-z}}{(1+e^{-z})^2}=\sigma\cdot(1-\sigma) σ=(1+ez)2ez=σ(1σ)
由上式可知, σ ′ \sigma' σ的变化趋势为:先递增,然后递减,并在 σ = 1 2 \sigma=\dfrac{1}{2} σ=21,也即 z = 0 z=0 z=0时取得最大值 1 4 \dfrac{1}{4} 41。我们画出 σ ′ \sigma' σ的图像,如下图所示:
在这里插入图片描述
因此,我们得到:
σ ′ ( z ) ≤ 1 4 \sigma'(z)\le \frac{1}{4} σ(z)41

3. 梯度消失与梯度爆炸的原因

神经网络的参数采用随机初始化的方法,取值往往会小于1,也即 ∣ w ∣ < 1 |w|<1 w<1,所以式(1)中,梯度 ∂ L ∂ w 1 \dfrac{\partial L}{\partial w_1} w1L随着网络深度的增加,数值一直在减小,并且传递到第 i i i层时,数值小于第 i + 1 i+1 i+1层的 1 4 \dfrac{1}{4} 41

如果网络非常深,就可能会使得前几层网络参数的梯度接近于0,也就发生了梯度消失现象。

另外,如果网络的初始化参数 ∣ w ∣ |w| w比较大,会造成 ∣ σ ′ ( z ) w ∣ > 1 |\sigma'(z)w|>1 σ(z)w>1的情况,此时式(1)中,梯度 ∂ L ∂ w 1 \dfrac{\partial L}{\partial w_1} w1L随着网络深度的增加,数值一直增大。如果网络非常深,就可能会使得前几层网络参数的梯度非常大,也就发生了梯度爆炸现象。

4. 一些其他的激活函数

4.1 tanh函数

tanh ⁡ ( z ) = e z − e − z e z + e − z \tanh(z)=\frac{e^{z}-e^{-z}}{e^z+e^{-z}} tanh(z)=ez+ezezez
tanh ⁡ ′ ( z ) = 1 − tanh ⁡ 2 ( z ) \tanh'(z)=1-\tanh^2(z) tanh(z)=1tanh2(z)
tanh函数和其导数的图形如下:
在这里插入图片描述
由上图可知,tanh函数的导数值也小于1,比sigmoid函数略好,但仍会导致梯度消失。

4.2 ReLU函数

ReLU ( z ) = max ⁡ ( x , 0 ) \text{ReLU}(z) = \max(x, 0) ReLU(z)=max(x,0)
ReLU ′ ( z ) = { 0 , z ≤ 0 1 , z > 0 \text{ReLU}'(z) = \left\{\begin{aligned}&0,&&z\le 0\\ &1,&&z>0\end{aligned}\right. ReLU(z)={0,1,z0z>0
ReLU函数和其导数的图形如下:
在这里插入图片描述
由上图可知,ReLU函数的导数,在正值部分恒为1,因此不会导致梯度消失或梯度爆炸问题。

另外ReLU函数还有一些优点

  1. 计算方便,计算速度快
  2. 解决了梯度消失问题,收敛速度快

ReLU函数的缺点

  1. 输出不是zero-centered,同sigmoid一样
  2. 某些神经元可能永远不会激活,导致相应的参数永远不会被更新。(解决方法:减小learning rate;Xavier初始化方法)
  3. 数据没有被压缩,数据的幅值会随着网络层数的增加而不断扩张。

5. 解决方案

5.1 Batch Normalization 批标准化

BN将网络中每一层的输出标准化为正态分布,并使用缩放平移参数对标准化之后的数据分布进行调整,可以将集中在梯度饱和区的原始输出拉向线性变化区,增大梯度值,缓解梯度消失问题,并加快网络的学习速度。

详细解读,请参考 Batch Normalization 批标准化及其相关数学原理和推导

5.2 选用ReLU、leak ReLU、eReLU等激活函数

ReLU函数的性质,在上文已有相关表述。

5.3 使用梯度剪切预防梯度爆炸

为梯度设置一个阈值,在更新过程中,超过这个阈值的梯度被强制限定在这一范围,防止梯度爆炸。

5.4 使用正则化 (regularization) 预防梯度爆炸

在损失函数中加入参数的正则化项,在防止过拟合的同时,还可以防止梯度爆炸。
L = ∥ w x − y ∥ 2 + γ ∥ w ∥ 2 L = \left\Vert wx - y\right\Vert^2 + \gamma\left\Vert w\right\Vert^2 L=wxy2+γw2

5.5 使用残差网络ResNet

使用ResNet可以轻松搭建几百层、上千层的网络,而不用担心梯度消失问题。

5.6 LSTM网络

你可能感兴趣的:(人工智能/深度学习/机器学习)