长短期记忆网络(LSTM)是什么怎么工作的呢?

深度学习系列

第一篇 局部最优点+鞍点+学习率的调节
第二篇 并行计算 深度学习 机器学习
第三篇 长短期记忆网络(LSTM)是什么怎么工作的呢?


文章目录

  • 深度学习系列
  • 前言
  • 一、SampleRNN其实有不同的形式
  • 二、LSTM基本知识
    • 1. 为什么使用LSTM?
      • 先来看RNN的缺点
      • RNN为什么会梯度消失/爆炸
    • 2.LSTM做了什么?
    • 3.LSTM的缺点
  • 四、Reference
  • 总结


前言

LSTM,全称Long Short-Term Memory,可以说是RNN的一个进阶的版本,它们都是考虑了过去时间点的数据,综合调整当前的网络参数。我们现在说使用了RNN,一般指的就是使用了LSTM,而普通的RNN,应该叫SimpleRNN。


一、SampleRNN其实有不同的形式

简单的例如下面的Elman network 或者 Jordan network。它们都是简单的RNN,但是形式有所不同。不同点在于传入下一个的RNNCell的参数不一样,Elman Network是hidden层的输出作为下一层的输入,而Jordan Network是上一层最终的输入作为输入。
长短期记忆网络(LSTM)是什么怎么工作的呢?_第1张图片
就如下图,对于1层的RNN,序号1指的是hidden层的输出,而序号2才是最终的output输出。
长短期记忆网络(LSTM)是什么怎么工作的呢?_第2张图片

二、LSTM基本知识

1. 为什么使用LSTM?

先来看RNN的缺点

RNN的关键点就是他们可以用来连接先前的信息到当前的任务上。例如我们使用RNN去给句子填词,假如我们试着预测 “the clouds are in the sky” 最后的词sky,我们并不需要任何其他的上下文,下一个词很显然就应该是 sky。在这样的场景中,相关的信息和预测的词位置之间的间隔是非常小的,普通RNN 可以学会使用先前的信息。但是!! 在一些更加复杂的场景。如我们去预测“I grew up in France , … ,I speak fluent French”最后的词French。我们需要知道前文对应是什么国家,我们是需要用到离当前位置很远的 France 的上下文的。但是中间…说明有很多其他句子,这使得French和France的间隔很大。在间隔不断增大时,SimpleRNN会丧失学习到连接如此远的信息的能力

另外一个点就是RNN有梯度消失的缺点,但是笔者认为这个缺点和上面那个缺点是有共通性的。首先,我们说的梯度消失,指的是到了深处梯度消失,在浅一点的位置并不会消失,也正是深处的梯度消失了,浅位置对于深位置的依赖关系难以学到(更新不了)。

RNN为什么会梯度消失/爆炸

主要是在反向传播的过程中,有了连乘。如下图,这是去推导公式的后,把结果整理写成通式。我们可以看到主要是hidden在求导的时候是连乘
RNN的梯度
而LSTM不一样,他往下传播当前状态的时候,除了hidden,还假如了cell(记忆细胞,也就是C),而C在更新的时候,采用了加法!!!!如下图,这就是最重点的!我们去求梯度的时候,有一部分是通过hidden这个中间节点(通俗说法,就是链式求导一个节点)去求梯度,但是还有一部分是通过C这个中间节点去求的,hidden部分还是会有梯度消失的问题,但是C部分求导后,其中的式子有LSTM中几个门对应的权重,我们可以调整门的权重,去控制C部分的梯度,所以C+hidden的总梯度就不会太小。(所以说LSTM加了门机制是重要的不同点)

长短期记忆网络(LSTM)是什么怎么工作的呢?_第3张图片

具体的推导过程可以看这个视频:大白话讲解LSTM长短期记忆网络 如何缓解梯度消失,手把手公式推导反向传播

这里也有一篇在公式上给出过程的博客:
LSTM缓解梯度消失的原因

2.LSTM做了什么?

长短期记忆网络(LSTM)是什么怎么工作的呢?_第4张图片上图是LSTM的一个基本单元,它有4个输入,分别对应3个门和一个输入值,这3个门都有对应的权重矩阵相乘,也有相应的激活函数。这个激活函数一般为sigmoid,使用它能让门的输出值为基本为0或者1,也就意味着门是关闭还是打开。(0和输入值相乘为0,说明输入门是关闭的)

具体过程是;输入值Z和其权重相乘后做tanh函数激活,乘是输入门Zi计算的值(0或者1),Zi也是乘相应矩阵再激活,此时输入门和输入值相乘后为g(z)f(zi),而遗忘门Zf中的计算的数值为f(zf) 会和记忆细胞Cell中的值c相乘再加上g(z)f(zi),如果f(zf) 为0,说明会遗忘以前的Cell中的值,如果为1,Cell中的值会被更新,然后再做激活。最后是输出门计算后的值f(z0),这个值乘上h(c’),作为输出。

当多个LSTM基本单元连起来后,就是我们下面这种结构,这个图是核心,必须看懂。我们可以看到每个单元的C(存储单元的值)和h(hidden层的值)都会往下一个单元传递。具体的传递过程,我觉得可以参考这篇文章:时序数据处理模型:RNN与LSTM总结
长短期记忆网络(LSTM)是什么怎么工作的呢?_第5张图片

3.LSTM的缺点

LSTM有一个缺点就是比起普通的neural network,它的参数多,因为它需要多个门来控制参数的变化,每个门都有对应的w和输入的值相乘,相当于有4倍的参数量。这些W的值也是通过gradient descent去更新

四、Reference

Understanding LSTM Networks
理解 LSTM 网络


总结

以上就是我个人对LSTM的理解,希望配合上其他文章,能让初学者更容易理解。如果觉得有用,请大家点赞支持!!!!

你可能感兴趣的:(DeepLearning,lstm,深度学习,rnn)