Why LSTMs Mitigate the Vanishing Gradient Problem


This post works through the equations to show why LSTMs mitigate the vanishing gradient problem.

The structure of an LSTM is shown below:

[Figure 1: LSTM structure]

The cell state is updated as:

$$C_t = C_{t-1}\otimes\sigma\left(W_f\cdot[H_{t-1},X_t]\right) + \tanh\left(W_c\cdot[H_{t-1},X_t]\right)\otimes\sigma\left(W_i\cdot[H_{t-1},X_t]\right)$$

where $\sigma$ is the sigmoid function, $\otimes$ denotes element-wise multiplication, and $[H_{t-1},X_t]$ is the concatenation of the previous hidden state and the current input; bias terms are omitted.
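As a minimal sketch of this update (pure Python, plain lists as vectors, biases omitted as in the formula; `matvec` and `cell_update` are hypothetical helper names, not from any library), one step of the cell-state recurrence can be written as:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(W, v):
    """W · v for a matrix stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def cell_update(C_prev, H_prev, X_t, W_f, W_i, W_c):
    """C_t = C_{t-1} ⊗ σ(W_f·[H_{t-1}, X_t]) + tanh(W_c·[H_{t-1}, X_t]) ⊗ σ(W_i·[H_{t-1}, X_t])."""
    hx = H_prev + X_t                                # list concatenation gives [H_{t-1}, X_t]
    f = [sigmoid(z) for z in matvec(W_f, hx)]        # forget gate
    i = [sigmoid(z) for z in matvec(W_i, hx)]        # input gate
    g = [math.tanh(z) for z in matvec(W_c, hx)]      # candidate cell state
    return [c * fj + gj * ij for c, fj, gj, ij in zip(C_prev, f, g, i)]

# Usage: hidden size 2, input size 3, so each W has shape 2 x 5.
Z = [[0.0] * 5 for _ in range(2)]
print(cell_update([1.0, -2.0], [0.0, 0.0], [0.0, 0.0, 0.0], Z, Z, Z))  # prints [0.5, -1.0]
```

With all-zero weights both gates output 0.5 and the candidate is 0, so the cell state is simply halved; the forget-gate factor is what carries $C_{t-1}$ into $C_t$.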

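Following only the gradient path that runs from the loss $E_k$ at step $k$ back through the cell states to step 1, backpropagation gives: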

$$\begin{aligned}\frac{\partial E_k}{\partial W} &= \frac{\partial E_k}{\partial H_k}\frac{\partial H_k}{\partial C_k}\frac{\partial C_k}{\partial C_{k-1}}\cdots\frac{\partial C_2}{\partial C_1}\frac{\partial C_1}{\partial W} \\ &= \frac{\partial E_k}{\partial H_k}\frac{\partial H_k}{\partial C_k}\left(\prod_{t=2}^{k}\frac{\partial C_t}{\partial C_{t-1}}\right)\frac{\partial C_1}{\partial W}\end{aligned}$$
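The chain rule can be checked numerically on a toy model. The sketch below assumes a hypothetical scalar LSTM with a constant output gate `O` (so $H_{t-1} = O\cdot\tanh(C_{t-1})$ and the cell state is the only recurrent state, which is exactly the single path the formula follows), made-up weights `W_F`, `W_I`, `W_C`, and no biases. It estimates each local factor $\partial C_t/\partial C_{t-1}$ by central differences, multiplies them, and compares the product with a direct finite-difference estimate of $\partial C_k/\partial C_1$.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical scalar LSTM (illustrative values, not from the post):
# constant output gate O, so H_{t-1} = O * tanh(C_{t-1}) and C is the only recurrent state.
O = 0.8
W_F = (0.5, 1.0)    # forget gate: (weight on H_{t-1}, weight on X_t); biases omitted
W_I = (-0.3, 0.7)   # input gate
W_C = (0.9, -0.4)   # candidate cell state

def step(C_prev, x):
    """C_t = C_{t-1} * f + g * i for one time step."""
    H_prev = O * math.tanh(C_prev)
    f = sigmoid(W_F[0] * H_prev + W_F[1] * x)
    i = sigmoid(W_I[0] * H_prev + W_I[1] * x)
    g = math.tanh(W_C[0] * H_prev + W_C[1] * x)
    return C_prev * f + g * i

def run(C1, xs):
    """Unroll from C_1 over inputs X_2..X_k and return C_k."""
    C = C1
    for x in xs:
        C = step(C, x)
    return C

def chain_rule_grad(C1, xs, eps=1e-6):
    """∂C_k/∂C_1 as the product of local factors ∂C_t/∂C_{t-1} (central differences)."""
    C, prod = C1, 1.0
    for x in xs:
        prod *= (step(C + eps, x) - step(C - eps, x)) / (2 * eps)
        C = step(C, x)
    return prod

def direct_grad(C1, xs, eps=1e-6):
    """∂C_k/∂C_1 estimated directly by perturbing C_1 and re-running the sequence."""
    return (run(C1 + eps, xs) - run(C1 - eps, xs)) / (2 * eps)

xs = [0.5, -1.0, 2.0, 0.1, -0.3]   # X_2..X_6, so k = 6
print(chain_rule_grad(0.2, xs), direct_grad(0.2, xs))
```

The two estimates agree up to finite-difference error, which is the point of the product form: the gradient reaching $C_1$ is the product of the per-step factors.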

In the backpropagation formula, the part in parentheses is a running product; each of its factors is:

$$\frac{\partial C_t}{\partial C_{t-1}} = \sigma\left(W_f\cdot[H_{t-1},X_t]\right) + \left[C_{t-1}\otimes\frac{\partial}{\partial C_{t-1}}\sigma\left(W_f\cdot[H_{t-1},X_t]\right) + \frac{\partial}{\partial C_{t-1}}\Big(\tanh\left(W_c\cdot[H_{t-1},X_t]\right)\otimes\sigma\left(W_i\cdot[H_{t-1},X_t]\right)\Big)\right]$$
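A minimal numerical check of this factor, assuming a hypothetical scalar LSTM with a constant output gate `O` (so $H_{t-1} = O\cdot\tanh(C_{t-1})$), illustrative weights, and no biases: `local_factor` (a hypothetical helper) splits the analytic $\partial C_t/\partial C_{t-1}$ into the forget-gate value, the term $C_{t-1}\otimes\partial f/\partial C_{t-1}$, and the derivative of the input-gate product, then the sum is compared with a central-difference estimate.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def step(C_prev, x, O, wf, wi, wc):
    """One cell-state update of a hypothetical scalar LSTM.
    Each w* = (weight on H_{t-1}, weight on X_t); biases omitted;
    O is a constant output gate, so H_{t-1} = O * tanh(C_{t-1})."""
    H = O * math.tanh(C_prev)
    f = sigmoid(wf[0] * H + wf[1] * x)
    i = sigmoid(wi[0] * H + wi[1] * x)
    g = math.tanh(wc[0] * H + wc[1] * x)
    return C_prev * f + g * i

def local_factor(C_prev, x, O, wf, wi, wc):
    """Analytic ∂C_t/∂C_{t-1}, returned as its three terms."""
    t = math.tanh(C_prev)
    H = O * t
    dH = O * (1.0 - t * t)                      # ∂H_{t-1}/∂C_{t-1}
    f = sigmoid(wf[0] * H + wf[1] * x)
    i = sigmoid(wi[0] * H + wi[1] * x)
    g = math.tanh(wc[0] * H + wc[1] * x)
    df = f * (1.0 - f) * wf[0] * dH             # ∂f/∂C_{t-1}
    di = i * (1.0 - i) * wi[0] * dH             # ∂i/∂C_{t-1}
    dg = (1.0 - g * g) * wc[0] * dH             # ∂g/∂C_{t-1}
    term_forget = f                             # forget-gate value
    term_gate = C_prev * df                     # C_{t-1} ⊗ ∂f/∂C_{t-1}
    term_input = dg * i + g * di                # ∂(g ⊗ i)/∂C_{t-1}
    return term_forget, term_gate, term_input

params = dict(O=0.8, wf=(0.5, 1.0), wi=(-0.3, 0.7), wc=(0.9, -0.4))
C_prev, x, eps = 0.2, 0.5, 1e-6
terms = local_factor(C_prev, x, **params)
numeric = (step(C_prev + eps, x, **params) - step(C_prev - eps, x, **params)) / (2 * eps)
print(terms, sum(terms), numeric)
```

Setting the weights on $H_{t-1}$ to zero (e.g. `wf=(0.0, 1.0)`) makes the last two terms vanish, leaving exactly the forget-gate value.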

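In other words, each factor in the product is a sum of two parts. The first is the forget-gate value; the second is nonzero only because the gates depend on $H_{t-1}$, which itself depends on $C_{t-1}$. The forget gate determines what fraction of the previous cell state is kept, and its value can be close to 1, so the forget gate can be treated as $\sigma\left(W_f\cdot[H_{t-1},X_t]\right)\approx\mathbf{1}$ (a vector of ones). Keeping only the forget-gate part of each factor gives, for an LSTM: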

$$\frac{\partial E_k}{\partial W} \approx \frac{\partial E_k}{\partial H_k}\frac{\partial H_k}{\partial C_k}\left(\prod_{t=2}^{k}\sigma\left(W_f\cdot[H_{t-1},X_t]\right)\right)\frac{\partial C_1}{\partial W} \nrightarrow 0$$
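The approximation relies on the forget gate staying close to 1. A small sketch (scalar gates, each held at a fixed value purely for illustration; `forget_product` is a hypothetical helper) shows how strongly the product $\prod_{t=2}^{k} f_t$ depends on that: with $f = 0.5$ the product over 99 steps is negligible, while with $f$ close to 1 it stays of order 1.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forget_product(gates):
    """Π_{t=2}^{k} f_t for a list of scalar forget-gate values f_2..f_k."""
    prod = 1.0
    for f in gates:
        prod *= f
    return prod

k = 100
for z in (0.0, 2.0, 5.0, 8.0):   # forget-gate pre-activations, held fixed for illustration
    f = sigmoid(z)
    print(f"f = {f:.4f}: product over {k - 1} steps = {forget_product([f] * (k - 1)):.3e}")
```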

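Therefore, as long as the forget gate stays close to 1, this product does not shrink toward 0 as $k$ grows, which is why an LSTM mitigates the vanishing gradient problem.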
