RNN中的梯度消失/爆炸原因

RNN中的梯度消失/爆炸原因

梯度消失/梯度爆炸是深度学习中老生常谈的话题,这篇博客主要是对RNN中的梯度消失/梯度爆炸原因进行公式层面上的直观理解。

RNN中的梯度消失/爆炸原因_第1张图片

首先,上图是RNN的网络结构图, ( x 1 , x 2 , x 3 , … , ) (x_1, x_2, x_3, …, ) (x1,x2,x3,,)是输入的序列, X t X_t Xt表示时间步为 t t t时的输入向量。假设我们总共有 k k k个时间步,用第 k k k个时间步的输出 H k H_k Hk作为输出(实际上每个时间步都有输出,这里仅考虑 H k H_k Hk),用 E k E_k Ek表示损失。

其中, C t = tanh ⁡ ( W c C t − 1 + W x X t ) C_{t}=\tanh \left(W_{c} C_{t-1}+W_{x} X_{t}\right) Ct=tanh(WcCt1+WxXt)

从上式可以看出 W x W_x Wx W c W_c Wc其实是差不多的,记 W = [ W c , W x ] W=[W_c, W_x] W=[Wc,Wx],那么求偏导可以得到:

∂ E k ∂ W = ∂ E k ∂ H k ∂ H k ∂ C k ∂ C k ∂ C k − 1 … ∂ C 2 ∂ C 1 ∂ C 1 ∂ W = ∂ E k ∂ H k ∂ H k ∂ C k ( ∏ t = 2 k ∂ C t ∂ C t − 1 ) ∂ C 1 ∂ W \begin{aligned} \frac{\partial E_{k}}{\partial W}=& \frac{\partial E_{k}}{\partial H_{k}} \frac{\partial H_{k}}{\partial C_{k}} \frac{\partial C_{k}}{\partial C_{k-1}} \ldots \frac{\partial C_{2}}{\partial C_{1}} \frac{\partial C_{1}}{\partial W}=\\ & \frac{\partial E_{k}}{\partial H_{k}} \frac{\partial H_{k}}{\partial C_{k}}\left(\prod_{t=2}^{k} \frac{\partial C_{t}}{\partial C_{t-1}}\right) \frac{\partial C_{1}}{\partial W} \end{aligned} WEk=HkEkCkHkCk1CkC1C2WC1=HkEkCkHk(t=2kCt1Ct)WC1

其中的累乘部分为:

∂ C t ∂ c t − 1 = tanh ⁡ ′ ( W c C t − 1 + W x X t ) ⋅ d d C t − 1 [ W c C t − 1 + W x X t ] = tanh ⁡ ′ ( W c C t − 1 + W x X t ) ⋅ W c \begin{aligned} \frac{\partial C_{t}}{\partial c_{t-1}}=& \tanh ^{\prime}\left(W_{c} C_{t-1}+W_{x} X_{t}\right) \cdot \frac{d}{d C_{t-1}}\left[W_{c} C_{t-1}+W_{x} X_{t}\right]=\\ & \tanh ^{\prime}\left(W_{c} C_{t-1}+W_{x} X_{t}\right) \cdot W_{c} \end{aligned} ct1Ct=tanh(WcCt1+WxXt)dCt1d[WcCt1+WxXt]=tanh(WcCt1+WxXt)Wc

将该式代入上式有:

∂ E k ∂ W = ∂ E k ∂ H k ∂ H k ∂ C k ( ∏ t = 2 k tanh ⁡ ′ ( W c C t − 1 + W x X t ) ⋅ W c ) ∂ c 1 ∂ W \frac{\partial E_{k}}{\partial W}=\frac{\partial E_{k}}{\partial H_{k}} \frac{\partial H_{k}}{\partial C_{k}}\left(\prod_{t=2}^{k} \tanh ^{\prime}\left(W_{c} C_{t-1}+W_{x} X_{t}\right) \cdot W_{c}\right) \frac{\partial c_{1}}{\partial W} WEk=HkEkCkHk(t=2ktanh(WcCt1+WxXt)Wc)Wc1

观察这个式子,和上篇文章中一样,因为链式法则,出现了累乘项,因为tanh的导数 <= 1,所以,当k很大的时候,上式的值是趋向于0的。(<1的数多次相乘),也就是:

Π t = 2 k tanh ⁡ ′ ( W c C t − 1 + w x X t ) ⋅ W c → 0 , \Pi_{t=2}^{k} \tanh ^{\prime}\left(W_{c} C_{t-1}+w_{x} X_{t}\right) \cdot W_{c} \rightarrow 0, Πt=2ktanh(WcCt1+wxXt)Wc0, so ∂ E k ∂ W → 0 \frac{\partial E_{k}}{\partial W} \rightarrow 0 WEk0

此时,权重更新公式:

W ← W − α ∂ E k ∂ W ≈ W W \leftarrow W-\alpha \frac{\partial E_{k}}{\partial W} \approx W WWαWEkW

也就是说,RNN很容易出现梯度消失现象,使得参数更新缓慢,甚至是停止更新。

你可能感兴趣的:(深度学习)