DNN中的梯度消失/爆炸原因

DNN中的梯度消失/爆炸原因

梯度消失/梯度爆炸是深度学习中老生常谈的话题,这篇博客主要是对DNN中的梯度消失/梯度爆炸原因进行公式层面上的直观理解。

DNN中的梯度消失/爆炸原因_第1张图片

如上图所示,假设有2个隐层,前向传播公式:

f 1 = σ ( w 1 x + b 1 ) , z 1 = w 1 x + b 1 f_1 = \sigma(w_1x+b_1),z_1 = w_1x+b_1 f1=σ(w1x+b1)z1=w1x+b1

f 2 = σ ( w 2 f 1 + b 2 ) , z 2 = w 2 f 1 + b 2 f_2 = \sigma(w_2f_1+b_2),z_2 = w_2f_1+b_2 f2=σ(w2f1+b2)z2=w2f1+b2

f 3 = σ ( w 3 f 2 + b 3 ) , z 3 = w 3 f 2 + b 3 f_3 = \sigma(w_3f_2+b_3),z_3 = w_3f_2+b_3 f3=σ(w3f2+b3)z3=w3f2+b3

f 3 f_3 f3是输出层的神经元,所以可以认为 l o s s loss loss是关于 f 3 f_3 f3的函数。

l o s s loss loss反向传播的时候,我们可以对权重 w 3 , w 2 , w 1 w_3, w_2, w_1 w3,w2,w1进行更新:

∂ l o s s ∂ w 3 = ∂ l o s s ∂ f 3 ∂ f 3 ∂ w 3 = ∂ l o s s ∂ f 3 σ ′ ( w 3 f 2 + b 3 ) f 2 \frac{\partial loss}{\partial w_3} = \frac{\partial loss}{\partial f_3} \frac{\partial f_3}{\partial w_3} = \frac{\partial loss}{\partial f_3} \sigma^{'}(w_3f_2+b_3)f_2 w3loss=f3lossw3f3=f3lossσ(w3f2+b3)f2

∂ l o s s ∂ w 2 = ∂ l o s s ∂ f 3 ∂ f 3 ∂ f 2 ∂ f 2 ∂ w 2 = ∂ l o s s ∂ f 3 σ ′ ( w 3 f 2 + b 3 ) w 3 σ ′ ( w 2 f 1 + b 2 ) f 1 \frac{\partial loss}{\partial w_2} = \frac{\partial loss}{\partial f_3} \frac{\partial f_3}{\partial f_2} \frac{\partial f_2}{\partial w_2} = \frac{\partial loss}{\partial f_3} \sigma^{'}(w_3f_2+b_3)w_3 \sigma^{'}(w_2f_1+b_2)f_1 w2loss=f3lossf2f3w2f2=f3lossσ(w3f2+b3)w3σ(w2f1+b2)f1

∂ l o s s ∂ w 1 = ∂ l o s s ∂ f 3 ∂ f 3 ∂ f 2 ∂ f 2 ∂ f 1 ∂ f 1 ∂ w 1 = ∂ l o s s ∂ f 3 σ ′ ( w 3 f 2 + b 3 ) w 3 σ ′ ( w 2 f 1 + b 2 ) w 2 σ ′ ( w 1 x + b 1 ) x \frac{\partial loss}{\partial w_1} = \frac{\partial loss}{\partial f_3} \frac{\partial f_3}{\partial f_2} \frac{\partial f_2}{\partial f_1} \frac{\partial f_1}{\partial w_1} = \frac{\partial loss}{\partial f_3} \sigma^{'}(w_3f_2+b_3)w_3 \sigma^{'}(w_2f_1+b_2)w_2 \sigma^{'}(w_1x+b_1)x w1loss=f3lossf2f3f1f2w1f1=f3lossσ(w3f2+b3)w3σ(w2f1+b2)w2σ(w1x+b1)x

根据上面规律,我们可以把 x x x写成 f 0 f_0 f0,当有n-1层隐层时, f n f_n fn是输出,如果要求 w l w_l wl也就是第 l l l层的权重,反向传播中涉及的偏导计算为:

∂ l o s s ∂ w l = ∂ l o s s ∂ f n ∏ i = l n σ ′ ( w i f i − 1 + b i ) ∏ i = l + 1 n w i f l − 1 \frac{\partial loss}{\partial w_l } = \frac{\partial loss }{\partial {f_n} } \prod_{i=l}^{n}\sigma^{'}(w_if_{i-1} + b_i)\prod_{i=l+1}^{n}w_i f_{l-1} wlloss=fnlossi=lnσ(wifi1+bi)i=l+1nwifl1

上面这个式子就是我们要推导的核心!

当梯度反向传播到第 l l l层的时候,我们用上述公式计算偏导,根据链式法则,上面用大括号括起来的就是累乘项,其中前半部分是关于激活函数的导数的累乘,后半部分是关于权重值的累乘。我们知道,激活函数比如sigmoid函数,其导数的取值范围是 ( 0 , 1 4 ] (0, \frac{1}{4}] (0,41],是恒小于1的,当网络层数很深的时候,多个小于1的数进行累乘,结果是趋向于0的,也就是说此时,梯度反向传播的时候,根据参数更新公式 θ : = θ − α ⋅ ∂ l o s s θ \theta := \theta - \alpha \cdot \frac{\partial loss}{\theta} θ:=θαθloss,偏导部分的取值趋于0,那么该参数得不到更新,也就出现了我们说的梯度消失现象。

另外,我们也注意到,大括号的后半部分是关于权重值的累乘,当我们初始化权值很大的时候,多个大于1的数累乘,结果是 + ∞ +\infty +,此时就出现了梯度爆炸现象。

你可能感兴趣的:(深度学习)