RNN-BPTT 笔记

我主要是参考如下文章理解的:

数学 · RNN(二)· BPTT 算法 - 知乎

其中:

RNN-BPTT 笔记_第1张图片

 这一步划红线的地方是如何理解的:

从这张图可以大概看出Lt和W的关系:

RNN-BPTT 笔记_第2张图片

我一开始单纯的理解成Lt是W的高次项函数(这里将所有激活函数当线性函数去简化推导的复杂性),然后求导,发现这样 并不是这样理解的。

此时还原RNN的结构,RNN展开后中间要经过t层。这t层,每次都要乘以W。我们可以想成第1次乘以W1,第2次乘以W2,第3次乘以W3,以此类推。只是说这里的W1=W2=W3=...=W。然后前向传播是一样的效果。

然后其实我们本应该对W1,W2,W3...单独求导,然后更新W1,W2,W3...。但是为了结构、存储量简化,这里的W1,W2,W3...的更新也同步了(都用一个W表示了)。即W1,W2,W3的更新值也一样。但是Lt对W1,W2,W3...的偏导不一样,如何更新呢?这里按照公式推导的意思来看是把每个偏导都叠加了,然后给W更新。即:

\frac{\partial L_{t}}{\partial W}=\frac{\partial L_{t}}{\partial W_1}+\frac{\partial L_{t}}{\partial W_2}+\frac{\partial L_{t}}{\partial W_3}+...

此处(W_1=W_2=W_3=...=W

然后求出来的偏导用于给W更新。

还有就是这里的推导:

RNN-BPTT 笔记_第3张图片

 

这一步的推导可以参考:

softmax函数及对数似然函数的偏导数(推导过程)_Modozil的博客-CSDN博客_softmax 对数似然

你可能感兴趣的:(机器学习,rnn,深度学习,神经网络,BPTT)