RNN里的BPTT算法

这两天对RNN循环神经网络进行了学习,由一无所知到现在对什么是RNN以及它的前向传播和反向传播有了认识,尤其是BPTT算法的推导有些繁琐,但是推过一次后,对RNN反向传播求梯度的过程有了更清晰的认识。

下面是朴素的RNN循环神经网络图。(图1)
RNN里的BPTT算法_第1张图片

我在写博客前,自己先手写了一份推导过程。(图2)
RNN里的BPTT算法_第2张图片

为何BPTT更难?

因为多了状态之间的传递(即隐层单元之间的“交流”),根据前向传播算法,我们知道 s t ∗ = W s t − 1 + U x t , s_t^* = Ws_{t-1} + Ux_t , st=Wst1+Uxt, s t − 1 = f ( s t − 1 ∗ ) = f ( W s t − 2 + U x t − 1 ) s_{t-1} = f(s_{t-1}^*) = f(Ws_{t-2}+Ux_{t-1}) st1=f(st1)=f(Wst2+Uxt1),这说明 s t − 1 s_{t-1} st1也是关于 W W W的式子。

这样层层嵌套下去…就会追溯到 s 0 s_0 s0。可以意识到我们对 W 、 U W、U WU的梯度求解是繁琐的,而这正是BPTT的难点所在。对于 V V V的梯度求解,并没有受到状态之间传递的影响,因此和我们BP算法求解方式是一样的。

我们用 ∗ * 表示element-wise, × × ×表示矩阵乘法。
我们采用交叉熵损失函数,即 L t = − ( y t l o g ( o t ) + ( 1 − y t ) l o g ( 1 − o t ) ) L_t = - (y_tlog(o_t)+(1-y_t)log(1-o_t)) Lt=(ytlog(ot)+(1yt)log(1ot))
我们定义隐藏层的激活函数为sigmoid函数 s t = f ( s t ∗ ) s_t = f(s_t^*) st=f(st),输出层的激活函数也为sigmoid函数 o t = g ( o t ∗ ) o_t = g(o_t^*) ot=g(ot) f ′ = s t ∗ ( 1 − s t ) , g ′ = o t ∗ ( 1 − o t ) f' = s_t*(1-s_t), g' = o_t*(1-o_t) f=st(1st),g=ot(1ot) 。具体求导读者自行证明。

由前向传播可知, o t = g ( o t ∗ ) = g ( V s t ) o_t = g(o_t^*)=g(Vs_t) ot=g(ot)=g(Vst)

那么 ∂ L t ∂ V = ∂ L t ∂ o t ∗ ∂ o t ∂ o t ∗ ⋅ ∂ o t ∗ ∂ V = − ( y t o t + y t − 1 1 − o t ) ∗ o t ∗ ( 1 − o t ) ⋅ ∂ o t ∗ ∂ V = ( o t − y t ) × s t T \frac{\partial L_t}{\partial V} = \frac{\partial L_t}{\partial o_t}* \frac{\partial o_t}{\partial o_t^*}·\frac{\partial o_t^*}{\partial V} = -(\frac{y_t}{o_t}+\frac{y_t-1}{1-o_t})*o_t*(1-o_t)·\frac{\partial o_t^*}{\partial V} = (o_t-y_t)×s_t^ \mathrm{ T } VLt=otLtototVot=(otyt+1otyt1)ot(1ot)Vot=(otyt)×stT

不同时刻的 ∂ L t ∂ V \frac{\partial L_t}{\partial V} VLt要相加,得到最后的 ∂ L ∂ V \frac{\partial L}{\partial V} VL

由前向传播可知,对于时刻t而言, s t − 1 s_{t-1} st1也是关于 W W W的式子,因此我们在求 ∂ L t ∂ W \frac{\partial L_t}{\partial W} WLt时,不能简单的将 s t − 1 s_{t-1} st1视为常量,因此 ∂ L t ∂ W = ∑ k = 0 t ∂ L t ∂ s k ∗ × s k − 1 T \frac{\partial L_t}{\partial W} = \sum_{k=0}^t \frac{\partial L_t}{\partial s_k^*}×s_{k-1}^ \mathrm{ T } WLt=k=0tskLt×sk1T(注意,在我这里是把第一个时刻从0开始)。

∂ L t ∂ s t ∗ = ∂ L t ∂ o t ∗ ⋅ ∂ o t ∗ ∂ s t ∗ = V T × ( o t − y t ) ∗ s t ∗ ( 1 − s t ) \frac{\partial L_t}{\partial s_t^*} = \frac{\partial L_t}{\partial o_t^*}· \frac{\partial o_t^*}{\partial s_t^*}= V^\mathrm{T}×(o_t-y_t)*s_t*(1-s_t) stLt=otLtstot=VT×(otyt)st(1st)
∂ L t ∂ s k − 1 ∗ = ∂ L t ∂ s k ∗ ⋅ ∂ s k ∗ ∂ s k − 1 ∗ ∂ s k − 1 ∂ s k − 1 ∗ = s k − 1 ∗ ( 1 − s k − 1 ) ∗ W T × ∂ L t ∂ s k ∗ ( k = 1 , 2 , 3... t ) \frac{\partial L_t}{\partial s_{k-1}^*} = \frac{\partial L_t}{\partial s_k^*}· \frac{\partial s_k^*}{\partial s_{k-1}}*\frac{\partial s_{k-1}}{\partial s_{k-1}^*}= s_{k-1}*(1-s_{k-1})*W^\mathrm{T}×\frac{\partial L_t}{\partial s_k^*} (k=1,2,3...t) sk1Lt=skLtsk1sksk1sk1=sk1(1sk1)WT×skLt(k=1,2,3...t)

同理, ∂ L t ∂ U = ∑ k = 0 t ∂ L t ∂ s k ∗ × x k T \frac{\partial L_t}{\partial U} = \sum_{k=0}^t \frac{\partial L_t}{\partial s_k^*}×x_{k}^ \mathrm{ T } ULt=k=0tskLt×xkT

最后不同时刻的 ∂ L t ∂ U \frac{\partial L_t}{\partial U} ULt要相加,得到最终的 ∂ L ∂ U = ∑ t = 0 t = n ∂ L t ∂ U \frac{\partial L}{\partial U}= \sum_{t=0}^{t=n} \frac{\partial L_t}{\partial U} UL=t=0t=nULt

最后不同时刻的 ∂ L t ∂ W \frac{\partial L_t}{\partial W} WLt要相加,得到最终的 ∂ L ∂ W = ∑ t = 0 t = n ∂ L t ∂ W \frac{\partial L}{\partial W}= \sum_{t=0}^{t=n} \frac{\partial L_t}{\partial W} WL=t=0t=nWLt

[1] https://blog.csdn.net/zhaojc1995/article/details/80572098

你可能感兴趣的:(反向传播算法)