本文是根据以下三篇文章整理的LSTM推导过程,公式都源于文章,只是一些比较概念性的东西,要coding的话还要自己去吃透以下文章。


前向传播:
1、计算三个gate(in, out, forget)的输入和cell的输入:
zinj(t)=∑mwinjmym(t−1)+∑v=1SjwinjcvjScvj(t−1),(1) (1)zinj(t)=∑mwinjmym(t−1)+∑v=1SjwinjcjvScjv(t−1),
zφj(t)=∑mwφjmym(t−1)+∑v=1SjwφjcvjScvj(t−1),(2) (2)zφj(t)=∑mwφjmym(t−1)+∑v=1SjwφjcjvScjv(t−1),
zoutj(t)=∑mwoutjmym(t−1)+∑v=1SjwoutjcvjScvj(t−1),(3) (3)zoutj(t)=∑mwoutjmym(t−1)+∑v=1SjwoutjcjvScjv(t−1),
zctj(t)=∑mwctjmym(t−1)+∑v=1SjwctjcvjScvj(t−1),(4) (4)zcjt(t)=∑mwcjtmym(t−1)+∑v=1SjwcjtcjvScjv(t−1),
2、计算上述各个gate和cell的激活值:
yinj(t)=finj(zinj(t)),(5) (5)yinj(t)=finj(zinj(t)),
yφj(t)=fφj(zφj(t)),(6) (6)yφj(t)=fφj(zφj(t)),
youtj(t)=foutj(zoutj(t)),(7) (7)youtj(t)=foutj(zoutj(t)),
Scvj(0)=0,Scvj(t)=yφj(t)Scvj(t−1)+yinj(t)g(zcvj(t)),(8) (8)Scjv(0)=0,Scjv(t)=yφj(t)Scjv(t−1)+yinj(t)g(zcjv(t)),
ycvj(t)=youtjScvj(t),(9) (9)ycjv(t)=youtjScjv(t),
3、假定该网络为一个标准的三层结构(如下图所示),即一个输入层,一个隐层和一个输出层。则对于一个输出单元,我们可以按下述的方式计算它的输入和激活值。其中m为所有与该输出单元连接的单元(包括输入层的和隐层的)。

zk(t)=∑mwkmym(t),(10) (10)zk(t)=∑mwkmym(t),
yk(t)=fk(zk(t)),(11) (11)yk(t)=fk(zk(t)),
4、计算当前时间点对应状态对input gate和、forget gate以及cell的偏导数。这里跟CNN不一样,CNN前向只是求值,没有传递梯度。但对于lstm,由于内部状态的改变依赖前一时间点的状态,因此内部状态的参数也会把错误传递到网络下一层,因此前向也涉及到梯度传递。
dSjvin,m(t)=∂Scvj(t)∂winjm=tr∂Scvj(t−1)∂winjmyφj(t)+g(zcvj(t))f′inj(zinj(t))ym(t−1),(12) (12)dSin,mjv(t)=∂Scjv(t)∂winjm=tr∂Scjv(t−1)∂winjmyφj(t)+g(zcjv(t))f′inj(zinj(t))ym(t−1),
dSjvφm(t)=∂Scvj(t)∂wφjm=tr∂Scvj(t−1)∂wφjmyφj(t)+Scvj(t−1)f′φj(zφj(t))ym(t−1),(13) (13)dSφmjv(t)=∂Scjv(t)∂wφjm=tr∂Scjv(t−1)∂wφjmyφj(t)+Scjv(t−1)f′φj(zφj(t))ym(t−1),
dSjvcm(t)=∂Scvj(t)∂wcvjm=tr∂Scvj(t−1)∂wcvjmyφj(t)+g′(zcvj(t))yinj(t)ym(t−1),(14) (14)dScmjv(t)=∂Scjv(t)∂wcjvm=tr∂Scjv(t−1)∂wcjvmyφj(t)+g′(zcjv(t))yinj(t)ym(t−1),
后向传播:
1、对于每个输出单元(output unit),我们可以计算它的 输出错误如下,其中 tk(t) tk(t)为前向计算的输出, yk(t) yk(t)为真实值。
ek(t)=tk(t)−yk(t),(15) (15)ek(t)=tk(t)−yk(t),
2、接下来计算每个输出单元的残差,这里的计算和CNN是一样的,就是对该层网络求导。
δk(t)=f′k(zk)ek(t)(16) (16)δk(t)=f′k(zk)ek(t)
3、输出output gate的残差计算方式和output unit类似。(output unit只针对每一个小单元的权重,而output gate针对的是所有output unit连接到输出层的权重)
δoutj(t)=f′outj(zoutj(t))(∑Sjv=1h(Scvj(t))∑kwkcvjδk(t)),(17) (17)δoutj(t)=f′outj(zoutj(t))(∑v=1Sjh(Scjv(t))∑kwkcjvδk(t)),
4、第2和第3条针对的是外部残差,内部残差(包括input gate, forget gate和cell)计算方式如下:
eScvj(t)=youtj(t)h′(Scvj(t))(∑kwkcvjδk(t)),(18) (18)eScjv(t)=youtj(t)h′(Scjv(t))(∑kwkcjvδk(t)),
5、最后,根据残差更新各个参数(weight),注意外部和内部的表达式不一样,具体推导见原文。
1).output unit:
Δwkm(t)=αδk(t)ym(t−1),(19) (19)Δwkm(t)=αδk(t)ym(t−1),
2).output gate:
Δwout,m(t)=αδout(t)ym(t−1),(20) (20)Δwout,m(t)=αδout(t)ym(t−1),
3).input gate:
Δwin,m(t)=α∑Sjv=1eScvj(t)dSjvin,m(t),(21) (21)Δwin,m(t)=α∑v=1SjeScjv(t)dSin,mjv(t),
4).forget gate:
Δwφm(t)=α∑Sjv=1eScvj(t)dSjvφm(t),(22) (22)Δwφm(t)=α∑v=1SjeScjv(t)dSφmjv(t),
5).cell:
Δwcvjm(t)=αeScvj(t)dSjvcm(t),(23) (23)Δwcjvm(t)=αeScjv(t)dScmjv(t),