LSTM-基本原理-前向传播与反向传播过程推导

前言

最近在实践中用到LSTM模型,一直在查找相关资料,推导其前向传播、反向传播过程。
LSTM有很多变体,查到的资料的描述也略有差别,且有一些地方让我觉得有些困惑。目前查到的资料中我认为这个国外大神的博客写的比较清晰:
http://arunmallya.github.io/writeups/nn/lstm/index.html#/
这个博客中的有些步骤有一定跳跃性,本文中的描述主要基于这篇博客中的实现过程进行更细致的推导,在此分享。本人能力有限,如果有不妥当之处,欢迎大家交流、指正。

LSTM算法介绍

长短时记忆网络(Long Short Term Memory Network, LSTM),它成功的解决了原始循环神经网络的缺陷,成为当前最流行的RNN,在语音识别、图片描述、自然语言处理等许多领域中成功应用。
本人在应用方面也还没有太多涉猎,建议大家先看一下下面这篇经典博客:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
不想看英文的可以看一下中文版:
https://blog.csdn.net/roslei/article/details/61912618

基本网络结构&前向传播

先附上LSTM整体结构图:
LSTM-基本原理-前向传播与反向传播过程推导_第1张图片

后面的所有推导都是基于这张图的结构。
LSTM中比较重要的改进是加入了细胞状态 : Ct代表t时刻的细胞状态。
下面具体展开写一下各部分结构:

输入以及门计算部分

LSTM-基本原理-前向传播与反向传播过程推导_第2张图片
上述公式中表示了LSTM的四个基本结构:
输入门、输出门、遗忘门、用于更新细胞状态的部分(图中at 部分,这个好像没有专门的名称)
其中xt 为t时刻的输入,我们设定其大小为: n X 1
注:此处x直接把偏置项涵盖了,即加入了x0=1
ht 为t时刻的隐藏状态,我们设定其大小为: d X 1
ct 为t时刻的细胞状态,设定其大小为: d X 1
注:隐藏状态和细胞状态一般情况下维度一致。
W*: 维度大小:d X n
U* : 维度大小:d X d
注:W* 与 U*即为模型的参数矩阵,共八个,也就是说,训练LSTM时训练的就是这八个矩阵。
σ 代表sigmoid函数 , tanh 代表 tanh函数。
这里的σ和 tanh是按元素操作,以σ为例,σ(X)即对向量X中的每个元素Xi分别计算
σ(Xi),σ(X)与 X 维度相同。
图中 zt是为了表述方便采用的写法。将参数与输入分别整合为一个矩阵表示。

细胞状态更新部分

LSTM-基本原理-前向传播与反向传播过程推导_第3张图片
上图中公式即为细胞状态更新公式:
ct=it⊙at+ft⊙ct-1
其中⊙表示按元素乘(两矩阵维度一致,相同位置元素相乘,结果矩阵维度不变)。

输出部分

LSTM-基本原理-前向传播与反向传播过程推导_第4张图片
根据如下公式:
ht=ot⊙tanh(ct)
输出t时刻隐藏状态ht.
至此,LSTM的网络架构,前向传播部分就梳理结束了。

反向传播:梯度计算

为了理解反向传播过程,先看一下前向传播过程按时间将网络结构展开的效果图:
LSTM-基本原理-前向传播与反向传播过程推导_第5张图片
可以看到,t时刻的细胞状态ct对当前时刻隐藏状态ht和下一时刻的细胞状态ct+1都有贡献。所以计算ct位置的梯度时,需要考虑来自这两部分的梯度(全导数法则)。
注:此处默认的是多输入-多输出的情况,即每一时刻的隐藏状态ht均参与损失函数计算。
对应反向传播示意图如下图:
LSTM-基本原理-前向传播与反向传播过程推导_第6张图片
反向传播的终极目标是为了计算梯度,更新参数,所以
要计算损失函数对W* 与 U*的偏导数。
下面一步步推导:
LSTM-基本原理-前向传播与反向传播过程推导_第7张图片
不考虑损失函数的形式,这里我们泛化地设定:
这里写图片描述
然后将误差逐层反向传播。
注意:上图中的δct的等式右端其实只是来自ht部分的梯度,下文计算来自ct+1的梯度,二者相加才是真正的δct
上图中以及后文会用到以下函数的导数计算公式:
LSTM-基本原理-前向传播与反向传播过程推导_第8张图片

LSTM-基本原理-前向传播与反向传播过程推导_第9张图片
注意,根据上图中的
δct-1=δct⊙ft
我们可以得到:
δct=δct+1⊙ft+1
其实这部分只是来自ct+1的梯度。
综合前两张图中关于δct的计算可得:
δct=δht⊙ot⊙(1-tanh2(ct))+δct+1⊙ft+1
好,下面我们继续反向传播:
LSTM-基本原理-前向传播与反向传播过程推导_第10张图片

LSTM-基本原理-前向传播与反向传播过程推导_第11张图片
这里的写法稍微有些跳跃,其实不同的W* 与 U*的偏导数计算类似,所以原博客作者把他们整合在了一个表达式中。
这里以Wc为例具体算一下:
LSTM-基本原理-前向传播与反向传播过程推导_第12张图片
类似地可以计算其他参数矩阵的梯度,
最终写成图中的整合矩阵形式:
δWt =δzt X (ItT
至此,我们求出了在t时刻,损失函数相对于各参数的梯度。
LSTM-基本原理-前向传播与反向传播过程推导_第13张图片
最终,根据上式累加不同时刻的梯度,进行参数更新。

小结

以上就是LSTM基本的网络结构,以及前向、反向传播过程。由于时间和水平有限,文中可能有不妥当之处,欢迎大家批评指正,后续会不断修改。

你可能感兴趣的:(机器学习,模型推导)