通过之前的分享,我们已经了解RNN的基本结构,同时也了解了RNN存在的梯度消失的问题,这里我们通过一个例子简单回顾下RNN处理任务的逻辑及梯度消失带来的问题。
考虑用一个语言模型通过利用以前的文字信息来预测下一个文字。如果我们需要预测“the clouds are in the sky”这句话的最后一个字,我们不需要其他的信息,通过前面的语境就能知道最后一个字应该是sky。在这种情况下,相关信息与需要该信息的位置距离较近,RNN能够学习利用以前的信息来对当前任务进行相应的操作。如下图所示通过输入的 x 1 、 x 2 x_{1}、x_{2} x1、x2信息来预测出 h 3 h_{3} h3( x 0 、 h 0 x_{0}、h_{0} x0、h0一般为0)。
假设现在有个更为复杂的任务,考虑到下面这句话“I grew up in France… I speak fluent French.”,现在需要语言模型通过现有以前的文字信息预测该句话的最后一个字。通过以前文字语境可以预测出最后一个字是某种语言,但是要猜测出French,要根据之前的France语境。这样的任务,不同之前,因为这次的有用信息与需要进行处理信息的地方之间的距离较远,由于RNN存在梯度消失的问题,容易导致RNN不能学习到有用的信息,最终推导的任务可能失败。如下图所示。
LSTM可以缓解RNN梯度消失的问题。
Long Short Term Memory Networks(以下简称LSTM),一种特殊的RNN网络,该网络设计出来是为了解决长依赖问题。
所有循环神经网络都具有神经网络的重复模块链的形式。 在标准的RNN中,该重复模块将具有非常简单的结构,例如单个tanh层。标准的RNN网络如下图所示。
LSTM也具有这种链式结构,但是它的重复单元不同于标准RNN网络里的单元只有一个网络层,它的内部有四个网络层。LSTM的结构如下图所示。
对中间部分进行放大,并列上LSTM涉及到的计算公式,其中乘法( ∗ * ∗)是点乘( ⊙ \odot ⊙),如下图。
LSTM的核心是细胞状态,用贯穿细胞的水平线表示。
细胞状态像传送带一样。它贯穿整个细胞却只有很少的分支,这样能保证信息不变的流过整个RNN。细胞状态如下图所示。
相比RNN只有一个传递状态 h t h_t ht,LSTM有两个传输状态,一个 C t C_t Ct(cell state),和一个 h t h_t ht(hidden state)
即在 t 时刻:
LSTM 的输入有三个:当前时刻网络的输入值 x t x_t xt、上一时刻 LSTM 的输出值 h t − 1 h_{t-1} ht−1、以及上一时刻的单元状态 C t − 1 C_{t-1} Ct−1;
LSTM 的输出有两个:当前时刻 LSTM 输出值 h t h_t ht、和当前时刻的单元状态 C t C_t Ct。
其中对于传递下去的 C t C_t Ct改变得很慢,通常输出的 C t C_t Ct是上一个状态传过来的 C t − 1 C_{t-1} Ct−1加上一些数值。而 h t h_t ht则在不同节点下往往会有很大的区别。
LSTM网络能通过一种被称为门的结构对细胞状态进行删除或者添加信息。
门能够有选择性的决定让哪些信息通过。其实门的结构很简单,主要是由sigmoid函数进行控制。
因为sigmoid层的输出是0-1的值,这代表有多少信息能够流过sigmoid层。0表示都不能通过,1表示都能通过。
一个LSTM里面包含三个门来控制细胞状态,这三个门分别称为忘记门、输入门和输出门。下面一个一个的来讲述。
LSTM的第一步就是决定细胞状态需要丢弃哪些信息。这部分操作是通过一个称为忘记门的sigmoid单元来处理的。它通过查看 h t − 1 h_{t-1} ht−1和 x t x_{t} xt信息来输出一个0-1之间的向量,该向量里面的0-1值表示细胞状态 C t − 1 C_{t-1} Ct−1中的哪些信息保留或丢弃多少。0表示不保留,1表示都保留。忘记门如下图所示。
下一步是决定给细胞状态添加哪些新的信息。这一步又分为两个步骤,首先,利用 h t − 1 h_{t-1} ht−1和 x t x_{t} xt通过一个称为输入门的操作来决定更新哪些信息。然后利用 h t − 1 h_{t-1} ht−1和 x t x_{t} xt通过一个tanh层得到新的候选细胞信息 C ~ t \tilde C_{t} C~t,这些信息可能会被更新到细胞信息中。这两步描述如下图所示。
下面将更新旧的细胞信息 C t − 1 C_{t-1} Ct−1,变为新的细胞信息 C t C_{t} Ct。更新的规则就是通过忘记门选择忘记旧细胞信息的一部分,通过输入门选择添加候选细胞信息 C ~ t \tilde C_{t} C~t的一部分得到新的细胞信息 C t C_{t} Ct。更新操作如下图所示。
更新完细胞状态后需要根据输入的 h t − 1 h_{t-1} ht−1和 x t x_{t} xt来判断输出细胞的哪些状态特征,这里需要将输入经过一个称为输出门的sigmoid层得到判断条件,然后将细胞状态经过tanh层得到一个(-1, 1)之间值的向量,该向量与输出门得到的判断条件相乘就得到了最终该RNN单元的输出。该步骤如下图所示
最后的输出 y t y_{t} yt是由 h t h_{t} ht得来,分类问题可类似RNN经过非线性变换如softmax得来( y t = s o f t m a x ( W h y h t + b y ) y_{t}=softmax(W_{hy}h_{t}+b_{y}) yt=softmax(Whyht+by)),预测问题可直接输出。
这里还有一点需要注意,图中每一个门发挥作用的乘法( ∗ * ∗)是点乘( ⊙ \odot ⊙),即两个向量对应位置的元素相乘。
到此为止,我们介绍了LSTM的整体结构,并对三个门如何处理输入的进行了详细介绍,下面我们以一个例子来说明各个参数以及输入、输出数据的维度。
假设输入 x t ∈ R 5 x_{t}\in\mathbb{R}^{5} xt∈R5, h t ∈ R 3 h_{t}\in\mathbb{R}^{3} ht∈R3,则 C t ∈ R 3 C_{t}\in\mathbb{R}^{3} Ct∈R3,
f t = σ ( W f ∗ [ h t − 1 , x t ] + b f ) W f ∈ R 3 × 8 , b f ∈ R 3 f_{t}=\sigma(W_{f}*[h_{t-1},x_{t}]+b_{f})\ \ \ \ \ \ W_{f}\in\mathbb{R}^{3\times8}, b_{f}\in\mathbb{R}^{3} ft=σ(Wf∗[ht−1,xt]+bf) Wf∈R3×8,bf∈R3
i t = σ ( W i ∗ [ h t − 1 , x t ] + b i ) W i ∈ R 3 × 8 , b i ∈ R 3 i_{t}=\sigma(W_{i}*[h_{t-1},x_{t}]+b_{i})\ \ \ \ \ \ W_{i}\in\mathbb{R}^{3\times8}, b_{i}\in\mathbb{R}^{3} it=σ(Wi∗[ht−1,xt]+bi) Wi∈R3×8,bi∈R3
C ~ t = t a n h ( W C ∗ [ h t − 1 , x t ] + b C ) W C ∈ R 3 × 8 , b C ∈ R 3 \tilde{C}_{t}=tanh(W_{C}*[h_{t-1},x_{t}]+b_{C})\ \ \ \ \ \ W_{C}\in\mathbb{R}^{3\times8}, b_{C}\in\mathbb{R}^{3} C~t=tanh(WC∗[ht−1,xt]+bC) WC∈R3×8,bC∈R3
C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_{t}=f_{t}\odot C_{t-1}+i_{t}\odot\tilde{C}_{t} Ct=ft⊙Ct−1+it⊙C~t
o t = σ ( W o ∗ [ h t − 1 , x t ] + b o ) W o ∈ R 3 × 8 , b o ∈ R 3 o_{t}=\sigma(W_{o}*[h_{t-1},x_{t}]+b_{o})\ \ \ \ \ \ W_{o}\in\mathbb{R}^{3\times8}, b_{o}\in\mathbb{R}^{3} ot=σ(Wo∗[ht−1,xt]+bo) Wo∈R3×8,bo∈R3
h t = o t ⊙ t a n h ( C t ) h_{t}=o_{t}\odot tanh(C_{t}) ht=ot⊙tanh(Ct)
LSTM更为显著的变式是在2014年提出的门循环单元(Gate Recurrent Unit,简称GRU)。它将忘记门和输入门合并成一个新的门,称为更新门。GRU还有一个门称为重置门。如下图所示。(这里省略了偏置项 b b b)
其中 z t z_{t} zt是重置门, r t r_{t} rt是更新门,两者都是(0, 1)之间的向量,图中门发挥作用的乘法( ∗ * ∗)也是点乘( ⊙ \odot ⊙)。
通过重置门控来得到“重置”之后的数据 h t − 1 ′ = r t ⊙ h t − 1 h_{t-1}^{'}=r_{t}\odot h_{t-1} ht−1′=rt⊙ht−1,再将 h t − 1 ′ h_{t-1}^{'} ht−1′与输入 x t x_{t} xt进行拼接,最后通过一个tanh激活函数来将数据放缩到(-1, 1)的范围内。
更新表达式: h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t h_t = (1-z_{t}) \odot h_{t-1} + z_{t}\odot\tilde{h}_{t} ht=(1−zt)⊙ht−1+zt⊙h~t
首先再次强调一下,门控信号(这里的 z t z_{t} zt)的范围为0~1。门控信号越接近1,代表”记忆“下来的数据越多;而越接近0则代表”遗忘“的越多。
GRU很聪明的一点就在于,我们使用了同一个门控 z t z_{t} zt就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。
( 1 − z t ) ⊙ h t − 1 (1-z_{t}) \odot h_{t-1} (1−zt)⊙ht−1:表示对原本隐藏状态的选择性“遗忘”。这里的 1 − z t 1-z_{t} 1−zt可以想象成遗忘门(forget gate),忘记 h t − 1 h_{t-1} ht−1维度中一些不重要的信息。
z t ⊙ h ~ t z_{t}\odot\tilde{h}_{t} zt⊙h~t: 表示对包含当前节点信息的 h ~ t \tilde{h}_{t} h~t进行选择性”记忆“。与上面类似,这里的 z t z_{t} zt同理会忘记 h ~ t \tilde{h}_{t} h~t维度中的一些不重要的信息。或者,这里我们更应当看做是对 h ~ t \tilde{h}_{t} h~t维度中的某些信息进行选择。
h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t h_t = (1-z_{t}) \odot h_{t-1} + z_{t}\odot\tilde{h}_{t} ht=(1−zt)⊙ht−1+zt⊙h~t:结合上述,这一步的操作就是忘记传递下来的 h t − 1 h_{t-1} ht−1中的某些维度信息,并加入当前节点输入的某些维度信息。
总结为一句话,重置门决定了如何把新的输入与之前的记忆相结合,更新门决定多少先前的记忆起作用。如果我们把所有重置门设置为全1,更新门设置为全0,又达到了普通RNN的形式。
GRU与LSTM对比来看,一方面GRU的参数更少,因而训练稍快或需要更少的数据来泛化。另一方面,如果你有足够的数据,LSTM的强大表达能力可能会产生更好的结果。
LSTM的出现缓解了RNN梯度消失的痛点,这里我们将对LSTM如何缓解梯度消失的做一个讲解。
首先我们给出LSTM的计算公式。
f t = σ ( W f ∗ [ h t − 1 , x t ] + b f ) f_{t}=\sigma(W_{f}*[h_{t-1},x_{t}]+b_{f}) ft=σ(Wf∗[ht−1,xt]+bf)
i t = σ ( W i ∗ [ h t − 1 , x t ] + b i ) i_{t}=\sigma(W_{i}*[h_{t-1},x_{t}]+b_{i}) it=σ(Wi∗[ht−1,xt]+bi)
C ~ t = t a n h ( W C ∗ [ h t − 1 , x t ] + b C ) \tilde{C}_{t}=tanh(W_{C}*[h_{t-1},x_{t}]+b_{C}) C~t=tanh(WC∗[ht−1,xt]+bC)
C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_{t}=f_{t}\odot C_{t-1}+i_{t}\odot\tilde{C}_{t} Ct=ft⊙Ct−1+it⊙C~t
o t = σ ( W o ∗ [ h t − 1 , x t ] + b o ) o_{t}=\sigma(W_{o}*[h_{t-1},x_{t}]+b_{o}) ot=σ(Wo∗[ht−1,xt]+bo)
h t = o t ⊙ t a n h ( C t ) h_{t}=o_{t}\odot tanh(C_{t}) ht=ot⊙tanh(Ct)
我们之前的文章介绍过,RNN梯度消失的原因是由于在链式求导过程中,求导项中包含 ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}{\frac{\partial{S_{j}}}{\partial{S_{j-1}}}}=\prod_{j=k+1}^{t}{tanh^{'}}W_{s} ∏j=k+1t∂Sj−1∂Sj=∏j=k+1ttanh′Ws,由于激活函数tanh的导数是小于1的,因此随着累乘的增加,RNN会出现梯度消失的情况。
对于LSTM这种结构来说,其中最关键的就是cell state的传播流程,大部分文章说因为cell state的传播是靠加法的,所以有效抑制了梯度消失,这是不准确的。
cell state的传播公式在1997年版本的LSTM是这样的: C t = C t − 1 + i C ~ t C_{t}=C_{t-1}+i \tilde{C}_{t} Ct=Ct−1+iC~t
是没有遗忘门的,如果在这个版本说是因为加法有效抑制了梯度消失,还有一定道理,但很多人存在一个误解:97年版本的cell state的求导导数为1,梯度可以恒定传播,很多人忽略了后面 i C ~ t i \tilde{C}_{t} iC~t。不过对于这个版本的LSTM的代码来说,cell state反向传播导数确实为1,以为梯度截断去掉了后面那部分的影响。
但对于1997年版本的LSTM来说,即使考虑了后面那部分,导数依然不会小于1,梯度消失现象确实也就不会发生,但为什么好端端的后来就加了个遗忘门呢?
原因是cell state不能只进不出,当序列过长的时候,cell state后面会变成庞然大物,反而影响模型的效果,所以后来加入了遗忘门。加入遗忘门这个操作,可以说是更容易让LSTM产生梯度消失了,但相比遗忘门带来的收益,这点儿损失不算什么。
但是现在的LSTM在缓解梯度消失问题上的表现也是非常不错了,其原因还是在于BPTT((backpropagation though time))的过程中,接下来对这部分进行详细解释。
在对参数矩阵求导时,有多条求导路径,最后将这些求导路径相加得到最终的梯度,这里我们只关注 C t − 1 → C t = f t ⊙ C t − 1 + i t ⊙ C t ^ C_{t-1} \rightarrow C_t = f_t\odot C_{t-1} + i_t \odot \hat{C_t} Ct−1→Ct=ft⊙Ct−1+it⊙Ct^这条梯度路径,在这条路径中,类似于传统RNN的 ∏ j = k + 1 t ∂ S j ∂ S j − 1 \prod_{j=k+1}^{t}{\frac{\partial{S_{j}}}{\partial{S_{j-1}}}} ∏j=k+1t∂Sj−1∂Sj,我们主要关注 ∂ C t ∂ C t − 1 \frac{\partial C_{t}}{\partial C_{t-1}} ∂Ct−1∂Ct的值,具体来说,
∂ C t ∂ C t − 1 = ∂ C t ∂ f t ∂ f t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ i t ∂ i t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ C t ~ ∂ C t ~ ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ C t − 1 \frac{\partial C_{t}}{\partial C_{t-1}}=\frac{\partial C_{t}}{\partial f_{t}}\frac{\partial f_{t}}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial i_{t}}\frac{\partial i_{t}}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial \tilde{C_{t}}}\frac{\partial \tilde{C_{t}}}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial C_{t-1}} ∂Ct−1∂Ct=∂ft∂Ct∂ht−1∂ft∂Ct−1∂ht−1+∂it∂Ct∂ht−1∂it∂Ct−1∂ht−1+∂Ct~∂Ct∂ht−1∂Ct~∂Ct−1∂ht−1+∂Ct−1∂Ct
进一步简化,
∂ C t ∂ C t − 1 = C t − 1 ⊙ σ ′ ( . ) W f ∗ ( o t − 1 ⊙ t a n h ′ ( C t − 1 ) ) + C t ~ ⊙ σ ′ ( . ) W i ∗ ( o t − 1 ⊙ t a n h ′ ( C t − 1 ) ) + i t ⊙ t a n h ′ ( . ) W c ∗ ( o t − 1 ⊙ t a n h ′ ( C t − 1 ) ) + f t \frac{\partial C_{t}}{\partial C_{t-1}}=C_{t-1}\odot\sigma^{'}(.)W_{f}*(o_{t-1}\odot tanh^{'}(C_{t-1}))+\tilde{C_{t}}\odot\sigma^{'}(.)W_{i}*(o_{t-1}\odot tanh^{'}(C_{t-1}))+i_{t}\odot tanh^{'}(.)W_{c}*(o_{t-1}\odot tanh^{'}(C_{t-1}))+f_{t} ∂Ct−1∂Ct=Ct−1⊙σ′(.)Wf∗(ot−1⊙tanh′(Ct−1))+Ct~⊙σ′(.)Wi∗(ot−1⊙tanh′(Ct−1))+it⊙tanh′(.)Wc∗(ot−1⊙tanh′(Ct−1))+ft
这里给出一个粗略的化简形式,有兴趣的同学可以更加细化些,这里不过多赘述。
现在,如果我们要计算时刻k的,简单的利用上式进行乘t-k+1次即可。递归梯度计算上,LSTM与原始RNN最大的不同之处在于,在原始RNN中 ∂ S j ∂ S j − 1 \frac{\partial{S_{j}}}{\partial{S_{j-1}}} ∂Sj−1∂Sj会一直大于1或者在[0,1]区间,这会导致梯度爆炸或者消失。在 ∂ C t ∂ C t − 1 \frac{\partial C_{t}}{\partial C_{t-1}} ∂Ct−1∂Ct中,我们注意最后一项 f t f_{t} ft,即遗忘门,我们可以通过控制 f t f_{t} ft的大小来控制 ∂ C t ∂ C t − 1 \frac{\partial C_{t}}{\partial C_{t-1}} ∂Ct−1∂Ct,例如可以增大遗忘门 f t f_{t} ft,使得它能把 ∂ C t ∂ C t − 1 \frac{\partial C_{t}}{\partial C_{t-1}} ∂Ct−1∂Ct的值拉向1,这就可以减缓梯度消失(也就是梯度不会太快的消失)。因此 ∂ C t ∂ C t − 1 \frac{\partial C_{t}}{\partial C_{t-1}} ∂Ct−1∂Ct结果的取值范围并不一定局限在[0,1]中,而是有可能大于1的。另外一个重要的地方是 f t , i t , o t , C t ~ f_{t},i_{t},o_{t},\tilde{C_{t}} ft,it,ot,Ct~都是LSTM自己学习到的。所以说,LSTM会通过学习改变门控的值来决定什么时候遗忘梯度,什么时候保留梯度,即依靠学习得到权值去控制依赖的长度。
总结起来,LSTM能缓解梯度消失,其实主要是以下两点的结果:
1.cell状态的加法更新策略使得梯度传递更恰当,使得梯度更新有可能大于1。
2.门控单元可以决定遗忘多少梯度,他们可以在不同的时刻取不同的值。这些值都是通过隐层状态和输入的数据学习到的。
最后,LSTM依然不能完全解决梯度消失这个问题,有文献表示序列长度一般到了三百多仍然会出现梯度消失现象。如果想彻底规避这个问题,还是transformer好用。transformer的具体结构我们在其他文章中进行介绍。
针对LSTM缓解梯度下降的原因,这篇文章介绍的很详细:https://weberna.github.io/blog/2017/11/15/LSTM-Vanishing-Gradients.html