lstm结构图_深入理解RNN与LSTM

lstm结构图_深入理解RNN与LSTM_第1张图片

循环神经网络(Recurrent Neural Network)基础

在深度学习领域,神经网络已经被用于处理各类数据,如CNN在图像领域的应用,全连接神经网络在分类问题的应用等。随着神经网络在各个领域的渗透,传统以统计机器学习为主的NLP问题,也逐渐开始采用深度学习的方法来解决。如由Google Brain提出的Word2Vec模型,便将传统BoW等统计方法的词向量方法,带入到了以深度学习为基础的Distribution Representation的方法中来,真正地将NLP问题带入了深度学习的练兵场。当然,RNN的模型并非局限于NLP领域,而是为了解决一系列序列化数据的建模问题,如视频、语音等,而文本也只是序列化数据的一种典型案例。

RNN的特征在于,对于每个RNN神经元,其参数始终共享,即对于文本序列,任何一个输入都经过相同的处理,得到一个输出。在传统的全连接神经网络的结构中,神经元之间互不影响,并没有直接联系,神经元与神经元之间相互独立。而在RNN结构中,隐藏层的神经元开始通过一个隐藏状态所相连,通常会被表示为

。在理解RNN与全连接神经网络时,需要对两者的结构加以区分,通常,FCN会采用水平方式进行可视化理解,即每一层的神经元垂直排列,而不同层之间以水平方式排布。但在RNN的模型图中,隐藏层的不同神经元之间通常水平排列,而隐藏层的不同层之间以垂直方式排列,如图所示,在FCN网络中,各层水平布局,隐藏层各神经元相互独立,在RNN中,各层以垂直布局,而水平方向上布局着各神经元。
注意:RNN结构图只是为了使得结构直观易理解,而在水平方向上其实每个A都相同,对于每个时间步其都是采用同一个神经元进行前向传播。

lstm结构图_深入理解RNN与LSTM_第2张图片
FCN、RNN对比

RNN的前向传播

在RNN中,序列数据按照其时间顺序,依次输入到网络中,而时间顺序则表示时间步的概念。在RNN中,隐藏状态极为重要,隐藏状态是连接各隐藏层各神经元的中介值。如上图,在第一层中,在时间步

,RNN隐藏层神经元得到隐藏状态
,在时间步
,则接受来自上一个时间步的隐藏层输出
,得到新的隐藏状态
。而从垂直方向上看,各层之间,也通过隐藏状态所连接,对于
在水平的时间轴上,各神经元通过隐藏状态
连接,而层间还将接受前一层的
的值来作为
的值,从而获得到该层新的隐藏状态。因此,
RNN是一个在水平方向和垂直方向上,均可扩展的结构(水平方向上只是人为添加的易于理解的状态,在工程实践中不存在水平方向的设置)。

根据RNN的定义,可以简单地给出RNN的前向传播过程:

如上式,对于某一层,

均为模型需要学习的参数,通过上图RNN结构图的对应,则应为
层水平方向所有神经元的参数,
同一层的RNN单元参数相同,即参数共享。若考虑多层RNN,则可将上式改为:

为了简化研究,下文统一对单层RNN进行讨论。

值得注意的是,单层RNN前向传播可做如下变换:

为此,我们不妨将参数进行统一表示:

,其中
表示拼接操作,则前向传播变为

再获得隐藏状态后,若需要获得每一个时间步的输出,则需要进一步进行线性变换:

,其中
为参数,
为激活函数,如

lstm结构图_深入理解RNN与LSTM_第3张图片
RNN单元-Folded

针对单层RNN,可采用上述结构进行描述。

RNN的反向传播

为简化分析,选用RNN的最后时间步的隐藏状态(无输出层)直接作为输出层,即

,若为分类问题,则
通常为
。定义问题的损失函数为
,则在进行反向传播时,需要计算
的梯度,可进行如下推导:

然而,在RNN的反向传播中,不仅需要根据垂直方向进行梯度推导,同时需要根据水平方向,按照时间步进行梯度推导,即RNN中的BPTT(Back Propagation Through Time)反向传播。从公式中也可以看出,在前向传播过程中

是关于
的函数值,即
,则
可以进一步进行微分,于是将
关于
求偏导,以循着时间轴更新
时刻的

根据反向传播的规则,每个在当前时间步

应向前追溯直到
,计算梯度并更新参数,而在RNN中时间步中的
参数被所有步共享,因此梯度是对同一个参数计算,为此可以将梯度作求和,一次性更新至
,如图每个箭头表示一次梯度计算,则在
时刻计算梯度时,不仅需要直接计算当前时刻的梯度,还仍需根据时间轴,分别计算
时刻的梯度。

注:本推导在假设RNN仅使用一个输出,即最后一个时间步的输出为最终输出,而RNN在每个时间步均有输出,若考虑多个输出,则损失函数不同,即损失为各时间步损失的总和,而在计算梯度时,需要对每个时间步输出均计算一个输出,即

.

lstm结构图_深入理解RNN与LSTM_第4张图片
BPTT过程

则在

的过程中,
更新的梯度为

对于偏置

采用相同方式推导,此处不再重复推导。

注意:此处和后文若无特殊说明,均只讨论单层RNN,多层RNN则将RNN单元视为FCN中层即可。

RNN的梯度弥散与爆炸

根据上节的推导,可知,在进行BPTT时,RNN单元的反向传播梯度如下:

若激活函数

采用
,图像如图:

lstm结构图_深入理解RNN与LSTM_第5张图片
sigmoid与tanh函数图像

对激活函数求导,当

时,

时,

导数图像如下图所示:

lstm结构图_深入理解RNN与LSTM_第6张图片
sigmoid与tanh导数

从图像可以看出,在激活函数的两端,导数均介接近于0,根据上述RNN梯度的推导,假设当前处于最后一个时间步

,则在向前BPTT时,会得出
的计算,当
值接近于两端时,则其梯度异常接近于0,并且
导数最大值才为
,多个接近于0的数相乘,将导致梯度呈指数下降趋势,接近于0,导致梯度弥散。随着序列的变长,
的值越小,这便说明,
RNN不具备长期记忆,而只具备短期记忆。
由于梯度弥散,导致在序列长度很长时,无法在较后的时间步中,按照梯度更新较前时间步的
,导致无法根据后续序列来修改前向序列的参数,使得前向序列无法很好地做特征提取,
使得在长时间步过后,模型将无法再获取有效的前向序列记忆信息

梯度弥散,在RNN属于重要问题,为此便提出了以LSTM、GRU等结构的变种,来解决RNN短期记忆的瓶颈。同样的,根据上述梯度的推导,梯度中

将会导致参数累乘,若初始参数较大时,则较大数相乘,将导致
梯度爆炸,然而梯度爆炸相对于梯度弥散较容易解决,通常加入梯度裁剪即可一定程度缓解。

长短期记忆网络(Long Short Term Memory)

前面说到,RNN单元在面对长序列数据时,很容易便遭遇梯度弥散,使得RNN只具备短期记忆,即RNN面对长序列数据,仅可获取较近的序列的信息,而对较早期的序列不具备记忆功能,从而丢失信息。为此,为解决该类问题,便提出了LSTM结构,其核心关键在于:

  1. 提出了门机制:遗忘门、输入门、输出门
  2. 细胞状态:在RNN中只有隐藏状态的传播,而在LSTM中,引入了细胞状态。

LSTM的前向传播

如下图,为三个LSTM单元的连结,其中相较于传统RNN单元,其多了上下两条轴,分别用于承载细胞状态

及隐藏状态
的信息流动,而其中
则被称为门,通过乘运算于和运算实现数据的合并于过滤。

为更好地比较LSTM与RNN的区别,再此将RNN前向传播记录如下:

lstm结构图_深入理解RNN与LSTM_第7张图片
LSTM整体结构

紧接着,对LSTM的门进行定义,其均为:

其中,

分别表示遗忘门、输入门、输出门,对应地,
在不同门中,也应为不同的参数。为此,可卸除LSTM详细的前向传播过程。

lstm结构图_深入理解RNN与LSTM_第8张图片
LSTM单元

如图中各

,则表示各门,其与
运算做到了信息过滤和叠加。

在遗忘门:

,由之前所介绍的
函数可知,其函数值在
范围内。这里可以思考一下计算机中,门电路的思想,在逻辑电路中,分为“与门”,“或门”,“非门”等,对于“与门”,只有当两者均为1时结果为1,同样地对于遗忘门的运算,其输出值为
,当进行乘法运算时,是否也能达到信息过滤的效果呢?

lstm结构图_深入理解RNN与LSTM_第9张图片
遗忘门(高亮处)

结果很显然,当任何一个数乘以0时,其值为0,那么在后续的线性运算过程中其仍然为0,便可表示,其信息被忽略,因为到下一层时,其未产生信息叠加。

同理,对于输入门,我们有:

lstm结构图_深入理解RNN与LSTM_第10张图片
输入门(高亮处)

而输入门主要控制对输入的信息进行过滤,即在输入时选择性地抛弃某些信息,而抛弃的信息,即为输入门中输出为0的特征维度。同时,在时间步

,原输入应为:
,按照传统的RNN的前向传播,输入应经过线性变换后进行激活,并且激活函数通常使用
,即:

lstm结构图_深入理解RNN与LSTM_第11张图片
输出门(高亮处)
上述输入的变化,可以对应RNN的输入过程。

由于加了门机制,则需要对输入的信息,进行过滤,而输入信息在LSTM中包含:细胞状态、隐藏状态、当前时间步输入。其中隐藏状态、当前时间步,已经作为输入经过传统的RNN变换得到

,还剩下细胞状态,因此需要进一步将细胞状态与
融合,并得到新的细胞状态:

lstm结构图_深入理解RNN与LSTM_第12张图片
细胞状态更新

,其中
表示element-wise乘积,直观地理解:当前时间的细胞状态,等于之前时间的细胞状态经过遗忘门过滤再叠加输入门的信息和。

在输出门中,同样采用相同的方式得到门概率分布:

。输出门的作用在于,对于要输出给下一个时间步的信息,进行一定地过滤,有选择性地保留和去除之前时间步的某些数据。因此,有
。得到
后,便可进一步得到
,其过程与RNN一致。至此,LSTM的前向传播过程即以结束。

LSTM的结构有效地解决了RNN的短期依赖瓶颈。但是从模型结构可以看出,相较于RNN,LSTM含有更多的参数需要学习,从而导致LSTM的学习速度大大降低。

上述公式推导过程中,同样可以采用拼接的方式,使得
,而

对前向传播的过程进行整理,可得:

你可能感兴趣的:(lstm结构图)