lstm的数学推导

本文是根据以下三篇文章整理的LSTM推导过程,公式都源于文章,只是一些比较概念性的东西,要coding的话还要自己去吃透以下文章。

lstm的数学推导_第1张图片

 

lstm的数学推导_第2张图片

 

前向传播:

1、计算三个gate(in, out, forget)的输入和cell的输入:

zinj(t)=mwinjmym(t1)+v=1SjwinjcvjScvj(t1),(1) (1)zinj(t)=∑mwinjmym(t−1)+∑v=1SjwinjcjvScjv(t−1),

zφj(t)=mwφjmym(t1)+v=1SjwφjcvjScvj(t1),(2) (2)zφj(t)=∑mwφjmym(t−1)+∑v=1SjwφjcjvScjv(t−1),

zoutj(t)=mwoutjmym(t1)+v=1SjwoutjcvjScvj(t1),(3) (3)zoutj(t)=∑mwoutjmym(t−1)+∑v=1SjwoutjcjvScjv(t−1),

zctj(t)=mwctjmym(t1)+v=1SjwctjcvjScvj(t1),(4) (4)zcjt(t)=∑mwcjtmym(t−1)+∑v=1SjwcjtcjvScjv(t−1),

2、计算上述各个gate和cell的激活值:

yinj(t)=finj(zinj(t)),(5) (5)yinj(t)=finj(zinj(t)),

yφj(t)=fφj(zφj(t)),(6) (6)yφj(t)=fφj(zφj(t)),

youtj(t)=foutj(zoutj(t)),(7) (7)youtj(t)=foutj(zoutj(t)),

Scvj(0)=0,Scvj(t)=yφj(t)Scvj(t1)+yinj(t)g(zcvj(t)),(8) (8)Scjv(0)=0,Scjv(t)=yφj(t)Scjv(t−1)+yinj(t)g(zcjv(t)),

ycvj(t)=youtjScvj(t),(9) (9)ycjv(t)=youtjScjv(t),

3、假定该网络为一个标准的三层结构(如下图所示),即一个输入层,一个隐层和一个输出层。则对于一个输出单元,我们可以按下述的方式计算它的输入和激活值。其中m为所有与该输出单元连接的单元(包括输入层的和隐层的)。


lstm的数学推导_第3张图片

 

 

zk(t)=mwkmym(t),(10) (10)zk(t)=∑mwkmym(t),

yk(t)=fk(zk(t)),(11) (11)yk(t)=fk(zk(t)),

4、计算当前时间点对应状态对input gate和、forget gate以及cell的偏导数。这里跟CNN不一样,CNN前向只是求值,没有传递梯度。但对于lstm,由于内部状态的改变依赖前一时间点的状态,因此内部状态的参数也会把错误传递到网络下一层,因此前向也涉及到梯度传递。

dSjvin,m(t)=Scvj(t)winjm=trScvj(t1)winjmyφj(t)+g(zcvj(t))finj(zinj(t))ym(t1),(12) (12)dSin,mjv(t)=∂Scjv(t)∂winjm=tr∂Scjv(t−1)∂winjmyφj(t)+g(zcjv(t))f′inj(zinj(t))ym(t−1),

dSjvφm(t)=Scvj(t)wφjm=trScvj(t1)wφjmyφj(t)+Scvj(t1)fφj(zφj(t))ym(t1),(13) (13)dSφmjv(t)=∂Scjv(t)∂wφjm=tr∂Scjv(t−1)∂wφjmyφj(t)+Scjv(t−1)f′φj(zφj(t))ym(t−1),

dSjvcm(t)=Scvj(t)wcvjm=trScvj(t1)wcvjmyφj(t)+g(zcvj(t))yinj(t)ym(t1),(14) (14)dScmjv(t)=∂Scjv(t)∂wcjvm=tr∂Scjv(t−1)∂wcjvmyφj(t)+g′(zcjv(t))yinj(t)ym(t−1),


后向传播:
1、对于每个输出单元(output unit),我们可以计算它的 输出错误如下,其中 tk(t) tk(t)为前向计算的输出, yk(t) yk(t)为真实值。

ek(t)=tk(t)yk(t),(15) (15)ek(t)=tk(t)−yk(t),

2、接下来计算每个输出单元的残差,这里的计算和CNN是一样的,就是对该层网络求导。

δk(t)=fk(zk)ek(t)(16) (16)δk(t)=f′k(zk)ek(t)

3、输出output gate的残差计算方式和output unit类似。(output unit只针对每一个小单元的权重,而output gate针对的是所有output unit连接到输出层的权重)

δoutj(t)=foutj(zoutj(t))(Sjv=1h(Scvj(t))kwkcvjδk(t)),(17) (17)δoutj(t)=f′outj(zoutj(t))(∑v=1Sjh(Scjv(t))∑kwkcjvδk(t)),

4、第2和第3条针对的是外部残差,内部残差(包括input gate, forget gate和cell)计算方式如下:

eScvj(t)=youtj(t)h(Scvj(t))(kwkcvjδk(t)),(18) (18)eScjv(t)=youtj(t)h′(Scjv(t))(∑kwkcjvδk(t)),

5、最后,根据残差更新各个参数(weight),注意外部和内部的表达式不一样,具体推导见原文。

1).output unit:

Δwkm(t)=αδk(t)ym(t1),(19) (19)Δwkm(t)=αδk(t)ym(t−1),

2).output gate:

Δwout,m(t)=αδout(t)ym(t1),(20) (20)Δwout,m(t)=αδout(t)ym(t−1),

3).input gate:

Δwin,m(t)=αSjv=1eScvj(t)dSjvin,m(t),(21) (21)Δwin,m(t)=α∑v=1SjeScjv(t)dSin,mjv(t),

4).forget gate:

Δwφm(t)=αSjv=1eScvj(t)dSjvφm(t),(22) (22)Δwφm(t)=α∑v=1SjeScjv(t)dSφmjv(t),

5).cell:

Δwcvjm(t)=αeScvj(t)dSjvcm(t),(23) (23)Δwcjvm(t)=αeScjv(t)dScmjv(t),

 

你可能感兴趣的:(lstm的数学推导)