图解LSTM——一文吃透LSTM

(欢迎大家关注我的微信公众号机器学习面试基地”,之后将在公众号上持续记录本人从非科班转到算法路上的学习心得、笔经面经、心得体会。未来的重点也会主要放在机器学习面试上!)

图解LSTM——一文吃透LSTM

  • v0版(20210817):本版本将通过图解LSTM的方式,逐步剖析LSTM的内部结构,力求把LSTM的结构和公式刻在大家的脑海中。当前版本并不会对LSTM的有效机制原因和反向传播进行分析,重点在于LSTM的结构展现上,也不会牵扯别的太多,就一个目的,让大家看透“LSTM”。

一、理解第一步:宏观认识

  • LSTM能干什么?

    简单说,就是将一串向量序列,转换为另一串含有更多特征信息的向量序列,也可以理解为embedding的一种方式。在实际应用中,输入的可以是一维的信号序列(也是向量序列,不过每个向量的长度为1),也可以是one-hot编码后的序列。那么输出的则是包含了特征信息或者上下文信息的向量序列(因此可以理解为特征向量序列也可以理解为embedding向量),公式中也称作隐藏单元向量。如下图所示:

    图解LSTM——一文吃透LSTM_第1张图片

  • 为什么要用LSTM,不能用RNN?

    LSTM的提出是为了解决RNN“长期依赖”的问题。简单来说,LSTM相比于RNN能够应对更长的输入序列。

  • 如何分类或者回归?

    但是有时候最终的任务并不是得到一个特征向量序列就结束了,而是想要做分类或者回归。这时候利用得到的输出向量序列来进行分类和回归即可,因为它们包含了丰富的特征信息。对于分类,可以将最后一个时刻的向量通过全连接层得到输出概率,如下图所示:

    对于回归,可以利用同一个全连接层将每个时刻的向量都转换为一个值,得到最终的回归序列。对于预测问题其实也是一样的,只不过回归的目标是未来时刻的序列:

    图解LSTM——一文吃透LSTM_第2张图片

二、理解第二步:LSTM的运行机制

上面第一步让我们对LSTM的输入和输出有了宏观认识,那么更进一步,已知这些输入,是如何得到上面的输出的呢?这里首先需要对输入输出进一步细化:在上面的图示中,我们可以看到LSTM产生了一个隐藏单元向量h,但是在实际过程中,随着时刻产生了不只是隐藏单元向量,还有一个细胞状态c。这里的细胞状态起着存储历史有效信息的作用:

图解LSTM——一文吃透LSTM_第3张图片

那么LSTM的整个运行机制是什么样的呢?

  • 初始化一个隐藏单元向量h0细胞状态c0

图解LSTM——一文吃透LSTM_第4张图片

  • 与当前时刻的输入x1一起输入到LSTM单元中,得到当前时刻的隐藏单元向量h1细胞状态c1
    图解LSTM——一文吃透LSTM_第5张图片

  • 然后将x2和隐藏单元向量h1细胞状态c1一起输入到LSTM单元中,得到隐藏单元向量h2细胞状态c2

图解LSTM——一文吃透LSTM_第6张图片

  • 以此循环,得到每个时刻的隐藏单元向量ht细胞状态ct

图解LSTM——一文吃透LSTM_第7张图片

而在LSTM单元中,还有相比于RNN更复杂的结构,比如遗忘门、输入门、输出门、细胞状态,这也是相比于RNN主要的不同之处。这些所谓的门结构,其实就是一个加上了激活函数的全连接层,然后输出0-1的数值,这些数值则表示特征信息能够保留的程度,0表示全部不保留,1则表示全部保留。

三、理解第三步:LSTM的核心——LSTM unit细节和LSTM公式

  • 先祭出让人头大的公式:(放了个公式大家肯定也记不住,后面会拆分理解和记忆)

    { f n = σ ( W f [ x n h n − 1 ] + b f ) i n = σ ( W i [ x n h n − 1 ] + b i ) c n = f n ⊙ c n − 1 + i n ⊙ tanh ⁡ ( W c [ x n h n − 1 ] + b c ) o n = σ ( W o [ x n h n − 1 ] + b o ) h n = o n ⊙ tanh ⁡ ( c n ) \left\{\begin{array}{l}\boldsymbol{f}_{\boldsymbol{n}}=\sigma\left(\boldsymbol{W}_{\boldsymbol{f}}\left[\begin{array}{c}x_{n} \\ \boldsymbol{h}_{n-1}\end{array}\right]+\boldsymbol{b}_{f}\right) \\ \boldsymbol{i}_{\boldsymbol{n}}=\sigma\left(\boldsymbol{W}_{\boldsymbol{i}}\left[\begin{array}{c}x_{n} \\ \boldsymbol{h}_{\boldsymbol{n}-\mathbf{1}}\end{array}\right]+\boldsymbol{b}_{\boldsymbol{i}}\right) \\ \boldsymbol{c}_{\boldsymbol{n}}=\boldsymbol{f}_{n} \odot \boldsymbol{c}_{\boldsymbol{n}-\mathbf{1}}+\boldsymbol{i}_{\boldsymbol{n}} \odot \tanh \left(\boldsymbol{W}_{\boldsymbol{c}}\left[\begin{array}{c}x_{n} \\ \boldsymbol{h}_{\boldsymbol{n}-1}\end{array}\right]+\boldsymbol{b}_{\boldsymbol{c}}\right) \\ \boldsymbol{o}_{\boldsymbol{n}}=\sigma\left(\boldsymbol{W}_{\boldsymbol{o}}\left[\begin{array}{c}x_{n} \\ \boldsymbol{h}_{n-\mathbf{1}}\end{array}\right]+\boldsymbol{b}_{\boldsymbol{o}}\right) \\ \boldsymbol{h}_{\boldsymbol{n}}=\boldsymbol{o}_{\boldsymbol{n}} \odot \tanh \left(\boldsymbol{c}_{\boldsymbol{n}}\right)\end{array}\right. fn=σ(Wf[xnhn1]+bf)in=σ(Wi[xnhn1]+bi)cn=fncn1+intanh(Wc[xnhn1]+bc)on=σ(Wo[xnhn1]+bo)hn=ontanh(cn)

    其中, f n f_n fn​​​, i n i_n in​​​, c n c_n cn​​​, o n o_n on​​​, h n h_n hn​​​​​​​​​​​分别表示当前时刻的遗忘门(forget)、输入门(input)、细胞状态(cell)、输出门(output)、隐藏单元向量(hidden),遗忘门的作用就是决定上一时刻的输出信息需要丢弃多少,输入门的作用在于判断当前时刻​​​​的输入信息哪些是有用的,需要留下来多少。输出门则是综合当前时刻信息和过去时刻信息后决定输出哪些信息。细胞状态则是可以看作一个存储库,存着各个时刻的有用信息。隐藏单元向量则是输入到下一时刻的信息。

    ​ 整个公式用一张图来表示就是:
    图解LSTM——一文吃透LSTM_第8张图片

  • 公式拆分、逐步理解
    图解LSTM——一文吃透LSTM_第9张图片
    图解LSTM——一文吃透LSTM_第10张图片
    图解LSTM——一文吃透LSTM_第11张图片
    图解LSTM——一文吃透LSTM_第12张图片
    图解LSTM——一文吃透LSTM_第13张图片
    图解LSTM——一文吃透LSTM_第14张图片

先写到这儿,后续再整理吧! 20210817

你可能感兴趣的:(深度学习,深度学习,lstm)