RNN中梯度消失和爆炸的问题公式推导

RNN

首先来看一下经典的RRN的结构图,这里 x x x 是输入 W W W 是权重矩阵 (RNN的权重矩阵是共享的所以都是W) h h h 是隐藏状态 y y y是输出

RNN简单公式定义

h t = W ∗ f ( h t − 1 ) + W ( h x ) ∗ x [ t ] h_t = W*f(h_{t-1}) + W^{(hx)}*x_{[t]} ht=Wf(ht1)+W(hx)x[t]
y t = W ( S ) ∗ f ( h t ) y_{t} = W^{(S)}*f(h_t) yt=W(S)f(ht)
其中, h t h_t ht表示 t 时刻的隐藏状态 x [ t ] x_{[t]} x[t] 表示 t 时刻的输入 y t y_t yt 表示 t 时刻的输出。我们记总体的error为 E E E 那么 E E E 有如下表达式:
E = ∑ t = 1 T ∂ E t ∂ W E = \sum_{t=1}^{T}\frac{\partial E_t}{\partial W} E=t=1TWEt
总体的误差是所有时刻 t 的误差的累加。那么继续往下展开, 根据链式法则:
∂ E t ∂ W = ∑ k = 1 t ∂ E t ∂ y t ∂ y t ∂ h t ∂ h t ∂ h k ∂ h k ∂ W \frac{\partial E_t}{\partial W} = \sum_{k=1}^{t}\frac{\partial E_t}{\partial y_t} \frac{\partial y_t}{\partial h_t}\frac{\partial h_t}{\partial h_k} \frac{\partial h_k}{\partial W} WEt=k=1tytEthtythkhtWhk
继续往下展开有:
∂ h t ∂ h k = ∏ j = k + 1 t ∂ h j ∂ h j − 1 \frac{\partial h_t}{\partial h_k} = \prod_{j=k+1}^{t}\frac{\partial h_j}{\partial h_{j-1}} hkht=j=k+1thj1hj
注意到: h t = W ∗ f ( h t − 1 ) + W ( h x ) ∗ x [ t ] h_t = W*f(h_{t-1}) + W^{(hx)}*x_{[t]} ht=Wf(ht1)+W(hx)x[t],上式的每个偏导其实是一个Jacobian式

考虑Jacobians的范数,令:
∣ ∣ ∂ h j ∂ h j − 1 ∣ ∣ ≤ ∣ ∣ W T ∣ ∣ ∗ ∣ ∣ d i a g [ f ′ ( h j − 1 ) ] ∣ ∣ ≤ β w ∗ β h ||\frac{\partial h_j}{\partial h_{j-1}} || \leq ||W^{T}|| *||diag[f'(h_{j-1})]|| \leq \beta_w*\beta_h hj1hjWTdiag[f(hj1)]βwβh
其中, β w , β h \beta_w ,\beta_h βw,βh 表示正则化的上界。将上式回代到连乘的式子得:
∣ ∣ ∂ h t ∂ h k ∣ ∣ = ∣ ∣ ∏ j = k + 1 t ∂ h j ∂ h j − 1 ∣ ∣ ≤ ( β w ∗ β h ) t − k ||\frac{\partial h_t}{\partial h_k} ||= ||\prod_{j=k+1}^{t}\frac{\partial h_j}{\partial h_{j-1}}|| \leq(\beta_w *\beta_h)^{t-k} hkht=j=k+1thj1hj(βwβh)tk
这里得 t 表示 time-step,也就是序列越长t会越大,即就变成了长期依赖的问题。注意到 ( β w ∗ β h ) t − k (\beta_w *\beta_h)^{t-k} (βwβh)tk 这项其实与矩阵的W的初始化有关,假设初始化一些非常小的数,W的范数也会变得很小,也就是 β w \beta_w βw会变得比较小,那么随着t的增长,这一指数项会趋近于0而导致梯度消失,相反,如果初始化成为大于1的数,则随着t的增长,会导致梯度爆炸。

你可能感兴趣的:(神经网络)