【机器学习】LSTM 讲解

2. LSTM

2.1. 长期依赖问题

标准 RNN 结构在理论上完全可以实现将最初的信息保留到即使很远的时刻,但是在实践中发现 RNN 会受到短时记忆的影响。如果一条序列足够长,那它们将很难将信息从较早的时刻传送到后面的时刻。 因此,如果正在尝试处理一段文本进行预测,RNN 可能从一开始就会遗漏重要信息。比如我们尝试预测 “I grew up in France … I speak fluent French” 这句话的最后一个词 ”French“ 。当前的信息(“I speak fluent”)表明接下来的单词是很可能是语言的名字。但是需要哪种语言,我们就要根据离当前位置很远的 “France” 来确定。这就说明相关信息和当前预测词的位置之间的间隔可能非常大,随着这间隔不断变大,RNN 就会失去学习连接如此远的信息的能力。 这就是我们上面提到的 RNN 最致命的缺点。

为了解决这个问题,提出了 LSTM 。

2.2. 网络结构

LSTM 属于 RNN 的扩展模型,二者的区别仅在于每个单元内部结构不同。LSTM 单元结构如下。

【机器学习】LSTM 讲解_第1张图片

其中,黄色矩形表示一层神经网络,包含权重和激活函数,矩形中的符号表明激活函数的类型, σ \sigma σ 对应 sigmoid 函数, t a n h \rm tanh tanh 对应 tanh 函数;粉色(椭)圆表示逐元素操作,比如粉色(椭)圆中为乘号表明矩阵进行对应元素相乘(点乘)操作, t a n h \rm tanh tanh 表明进行逐元素取 tanh 值。

下图展示了 LSTM 单元的完整前向传播过程。

从”遗忘门“、”输入门“和”输出门“,这三个”门“的角度来理解 LSTM 单元。

之所以称之为”门“,是考虑到生活中的”门“存在”开/闭“两种状态。LSTM 单元中的”门“也是存在”开/闭“两种状态,”开“表示全部(绝大部分)信息都可以经过”门“流出,”闭“表示全部(绝大部分)信息都不能经过”门“流出,而是被”门“过滤掉。由于 sigmoid 函数非常适合二分类,所以该函数在 LSTM 单元中起到”门“过滤的作用,用于控制信息是否流出(流出量)。

  • 遗忘门

    ”遗忘门“决定了前一个单元的状态 c t − 1 c_{t-1} ct1 有多少信息保留到当前单元状态 c t c_t ct 中。对应图中过程 [ h t − 1 , x t ] → f t [h_{t-1},x_t]\rightarrow f_t [ht1,xt]ft

  • 输入门

    ”输入门“决定了当前单元的输入 x t x_t xt 有多少信息保存到单元状态 c t c_t ct 。对应图中过程 [ h t − 1 , x t ] → i t [h_{t-1},x_t]\rightarrow i_t [ht1,xt]it

  • 输出门

    ”输出门“用于控制当前单元的状态 c t c_t ct 有多少信息输出到当前输出值 h t h_t ht 。对应图中过程 [ h t − 1 , x t ] → o t [h_{t-1},x_t]\rightarrow o_t [ht1,xt]ot 。​

模型单元的思想可以理解为, [ h t − 1 , x t ] [h_{t-1},x_t] [ht1,xt] 经过遗忘门确定保留多少前一个单元的信息, c t − 1 c_{t-1} ct1 σ ( W x f x t + W h f h t − 1 + b f ) \sigma(W_{xf}x_t+W_{hf}h_{t-1}+b_f) σ(Wxfxt+Whfht1+bf) 按位点乘实现筛选出要保留的信息; σ ( W x i x t + W h i h t − 1 + b i ) \sigma(W_{xi}x_t+W_{hi}h_{t-1}+b_i) σ(Wxixt+Whiht1+bi) t a n h ( W x g x t + W h g h t − 1 + b g ) {\rm tanh}(W_{xg}x_t+W_{hg}h_{t-1}+b_g) tanh(Wxgxt+Whght1+bg) 按位点乘实现从外部输入信息 x t x_t xt 中筛选出需要保留的信息,过滤到无用信息;将保留的原始信息和保留的外部信息按位相加,得到当前单元包含的信息 c t c_t ct t a n h ( c t ) {\rm tanh}(c_t) tanh(ct) 用于将每个单元的信息统一到一定范围内,再与 σ ( W x o x t + W h o h t − 1 + b o ) \sigma(W_{xo}x_t+W_{ho}h_{t-1}+b_o) σ(Wxoxt+Whoht1+bo) 按位点乘筛选出当前单元的全部信息中可以用于评估单元优劣的信息 h t h_t ht,对全部 h t h_t ht 进一步处理可以得到用于评估模型优劣的损失函数,同时也会直接传入到下一个单元,循环往复。

总结一下,整个流程是分为三个大部分,对应着三个”门“的操作。遗忘门部分筛选有用的内部信息,输入门筛选有用的外部信息,将两部分信息整合,输出门筛选用于评估单元优劣的信息。可以看到,每次的筛选操作都是通过 sigmoid 函数对 [ h t − 1 , x t ] [h_{t-1},x_t] [ht1,xt] 的线性映射进行非线性激活完成的。

2.3. 前向传播与反向传播

  • 前向传播

    前面已经讲解了。

  • 反向传播

    还是以计算图的形式说明反向传播过程。存在如下公式:
    f t = σ ( W x f x t + W h f h t − 1 + b f ) i t = σ ( W x i x t + W h i h t − 1 + b i ) g t = t a n h ( W x g x t + W h g h t − 1 + b g ) o t = σ ( W x o x t + W h o h t − 1 + b o ) c t = c t − 1 ⊙ f t + g t ⊙ i t h t = t a n h ( c t ) ⊙ o t L = ∑ l o s s ( h t , y t ) \begin{align} f_t&=\sigma(W_{xf}x_t+W_{hf}h_{t-1}+b_f) \tag{2.1}\\ i_t&=\sigma(W_{xi}x_t+W_{hi}h_{t-1}+b_i) \tag{2.2}\\ g_t&={\rm tanh}(W_{xg}x_t+W_{hg}h_{t-1}+b_g) \tag{2.3}\\ o_t&=\sigma(W_{xo}x_t+W_{ho}h_{t-1}+b_o) \tag{2.4}\\ c_t&=c_{t-1}\odot f_t+g_t\odot i_t \tag{2.5}\\ h_t&={\rm tanh}(c_t)\odot o_t \tag{2.6} \\ L&=\sum loss(h_t,y_t) \tag{2.7} \\ \end{align} ftitgtotcthtL=σ(Wxfxt+Whfht1+bf)=σ(Wxixt+Whiht1+bi)=tanh(Wxgxt+Whght1+bg)=σ(Wxoxt+Whoht1+bo)=ct1ft+gtit=tanh(ct)ot=loss(ht,yt)(2.1)(2.2)(2.3)(2.4)(2.5)(2.6)(2.7)
    一个单元的计算图如下。灰色框圈出的是一个单元涉及的计算关系,其他单元都可以类似地画出。

    【机器学习】LSTM 讲解_第2张图片

    我们引入 L t = l o s s ( h t , y t )   ( t = 1 , 2 , … , T ) L_t=loss(h_t,y_t)\space (t=1,2,\dots,T) Lt=loss(ht,yt) (t=1,2,,T) ,因此 L L L 可以表示为 L = ∑ t = 1 T L t L=\sum_{t=1}^TL_t L=t=1TLt 。反向传播过程如下。

    以计算 ∂ L ∂ W h f \frac{\partial L}{\partial W_{hf}} WhfL 为例推导公式,其他参数类似,推导的思路是根据反向传播过程按顺序推导每个结点代表的链式偏导。

    考虑最特别的 T T T 时刻,计算出损失函数(值)关于 T T T 时刻各个变量的偏导
    ∂ L ∂ L T ∂ L ∂ h T = ∂ L ∂ L T ∂ L T ∂ h T ∂ L ∂ o T = ∂ L ∂ h T ∂ h T ∂ o T ∂ L ∂ c T = ∂ L ∂ h T ∂ h T ∂ c T ∂ L ∂ f T = ∂ L ∂ c T ∂ c T ∂ f T ∂ L ∂ i T = ∂ L ∂ c T ∂ c T ∂ i T ∂ L ∂ g T = ∂ L ∂ c T ∂ c T ∂ g T ∂ L ∂ W h f ⟨ T ⟩ = ∂ L ∂ f T ∂ f T ∂ W h f + ∂ L ∂ i T ∂ i T ∂ W h f + ∂ L ∂ g T ∂ g T ∂ W h f \begin{align} \frac{\partial L}{\partial L_T} &\notag \\\notag \\ \frac{\partial L}{\partial h_T} &= \frac{\partial L}{\partial L_T}\frac{\partial L_T}{\partial h_T} \notag\\\notag \\ \frac{\partial L}{\partial o_T} &= \frac{\partial L}{\partial h_T} \frac{\partial h_T}{\partial o_T} \notag \\\notag \\ \frac{\partial L}{\partial c_T} &= \frac{\partial L}{\partial h_T} \frac{\partial h_T}{\partial c_T} \notag \\\notag \\ \frac{\partial L}{\partial f_T}&=\frac{\partial L}{\partial c_T}\frac{\partial c_T}{\partial f_T} \notag\\\notag \\ \frac{\partial L}{\partial i_T}&=\frac{\partial L}{\partial c_T}\frac{\partial c_T}{\partial i_T} \notag\\\notag \\ \frac{\partial L}{\partial g_T}&=\frac{\partial L}{\partial c_T}\frac{\partial c_T}{\partial g_T} \notag\\\notag \\ \frac{\partial L}{\partial W_{hf}^{\left\langle T \right\rangle}} &= \frac{\partial L}{\partial f_T} \frac{\partial f_T}{\partial W_{hf}} + \frac{\partial L}{\partial i_T} \frac{\partial i_T}{\partial W_{hf}}+\frac{\partial L}{\partial g_T} \frac{\partial g_T}{\partial W_{hf}} \notag \end{align} LTLhTLoTLcTLfTLiTLgTLWhfTL=LTLhTLT=hTLoThT=hTLcThT=cTLfTcT=cTLiTcT=cTLgTcT=fTLWhffT+iTLWhfiT+gTLWhfgT

    其中, ∂ L ∂ W h f ⟨ T ⟩ \frac{\partial L}{\partial W_{hf}^{\left\langle T \right\rangle}} WhfTL 表示 T T T 时刻对损失函数(值)关于 W h f W_{hf} Whf 偏导的贡献,满足 ∂ L ∂ W h f = ∑ t = 1 T ∂ L ∂ W h f ⟨ t ⟩ \frac{\partial L}{\partial W_{hf}} = \sum\limits_{t=1}^T \frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}} WhfL=t=1TWhftL

    根据式 ( 2.1 ) ∼ ( 2.7 ) (2.1)\sim (2.7) (2.1)(2.7) 将上面各式计算出来。 T T T 时刻各个变量的偏导总结如下。
    ∂ L ∂ L T = 1 ∂ L ∂ h T = ∂ L T ∂ h T ∂ L ∂ o T = ∂ L T ∂ h T t a n h ( c T ) ∂ L ∂ c T = ∂ L T ∂ h T o T t a n h ′ ( ⋅ ) ∂ L ∂ f T = ∂ L T ∂ h T o T t a n h ′ ( ⋅ ) c t − 1 ∂ L ∂ i T = ∂ L T ∂ h T o T t a n h ′ ( ⋅ ) g T ∂ L ∂ g T = ∂ L T ∂ h T o T t a n h ′ ( ⋅ ) i T ∂ L ∂ W h f ⟨ T ⟩ = ∂ L ∂ f T ∂ f T ∂ W h f = ∂ L ∂ h T o T t a n h ′ ( ⋅ ) c T − 1 σ ′ ( ⋅ ) h T − 1 \begin{align} \frac{\partial L}{\partial L_T} &=1\notag \\\notag \\ \frac{\partial L}{\partial h_T} &= \frac{\partial L_T}{\partial h_T} \notag\\\notag \\ \frac{\partial L}{\partial o_T} &= \frac{\partial L_T}{\partial h_T} {\rm tanh}(c_T) \notag \\\notag \\ \frac{\partial L}{\partial c_T} &= \frac{\partial L_T}{\partial h_T}o_T{\rm tanh'(·)} \notag \\\notag \\ \frac{\partial L}{\partial f_T}&=\frac{\partial L_T}{\partial h_T}o_T {\rm tanh'(·)}c_{t-1} \notag\\\notag \\ \frac{\partial L}{\partial i_T}&=\frac{\partial L_T}{\partial h_T}o_T {\rm tanh'(·)}g_T \notag\\\notag \\ \frac{\partial L}{\partial g_T}&=\frac{\partial L_T}{\partial h_T}o_T {\rm tanh'(·)}i_T \notag\\\notag \\ \frac{\partial L}{\partial W_{hf}^{\left\langle T \right\rangle}} &= \frac{\partial L}{\partial f_T} \frac{\partial f_T}{\partial W_{hf}} =\frac{\partial L}{\partial h_T}o_T {\rm tanh'(·)}c_{T-1}\sigma'(·) h_{T-1} \notag \end{align} LTLhTLoTLcTLfTLiTLgTLWhfTL=1=hTLT=hTLTtanh(cT)=hTLToTtanh()=hTLToTtanh()ct1=hTLToTtanh()gT=hTLToTtanh()iT=fTLWhffT=hTLoTtanh()cT1σ()hT1

    t = 1 , 2 , … , T − 1 t=1,2,\dots,T-1 t=1,2,,T1 时,计算出损失函数(值)关于 t t t 时刻刻个变量的偏导
    ∂ L ∂ L t ∂ L ∂ h t = ∂ L ∂ L t ∂ L t ∂ h t + ∂ L ∂ o t + 1 ∂ o t + 1 ∂ h t + ∂ L ∂ f t + 1 ∂ f t + 1 ∂ h t + ∂ L ∂ i t + 1 ∂ i t + 1 ∂ h t + ∂ L ∂ g t + 1 ∂ g t + 1 ∂ h t ∂ L ∂ o t = ∂ L ∂ h t ∂ h t ∂ o t ∂ L ∂ c t = ∂ L ∂ h t ∂ h t ∂ c t + ∂ L ∂ c t + 1 ∂ c t + 1 ∂ c t ∂ L ∂ f t = ∂ L ∂ c t ∂ c t ∂ f t ∂ L ∂ i t = ∂ L ∂ c t ∂ c t ∂ i t ∂ L ∂ g t = ∂ L ∂ c t ∂ c t ∂ g t ∂ L ∂ W h f ⟨ t ⟩ = ∂ L ∂ f t ∂ f t ∂ W h f + ∂ L ∂ i t ∂ i t ∂ W h f + ∂ L ∂ g t ∂ g t ∂ W h f \begin{align} \frac{\partial L}{\partial L_t} &\notag \\\notag \\ \frac{\partial L}{\partial h_t} &= \frac{\partial L}{\partial L_t}\frac{\partial L_t}{\partial h_t} + \frac{\partial L}{\partial o_{t+1}}\frac{\partial o_{t+1}}{\partial h_t} +\frac{\partial L}{\partial f_{t+1}} \frac{\partial f_{t+1}}{\partial h_{t}} + \frac{\partial L}{\partial i_{t+1}} \frac{\partial i_{t+1}}{\partial h_{t}}+\frac{\partial L}{\partial g_{t+1}} \frac{\partial g_{t+1}}{\partial h_{t}} \notag\\\notag \\ \frac{\partial L}{\partial o_t} &= \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial o_t} \notag \\\notag \\ \frac{\partial L}{\partial c_t} &= \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial c_t} + \frac{\partial L}{\partial c_{t+1}} \frac{\partial c_{t+1}}{\partial c_t} \notag \\\notag \\ \frac{\partial L}{\partial f_t}&=\frac{\partial L}{\partial c_t}\frac{\partial c_t}{\partial f_t} \notag\\\notag \\ \frac{\partial L}{\partial i_t}&=\frac{\partial L}{\partial c_t}\frac{\partial c_t}{\partial i_t} \notag\\\notag \\ \frac{\partial L}{\partial g_t}&=\frac{\partial L}{\partial c_t}\frac{\partial c_t}{\partial g_t} \notag\\\notag \\ \frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}} &= \frac{\partial L}{\partial f_t} \frac{\partial f_t}{\partial W_{hf}} + \frac{\partial L}{\partial i_t} \frac{\partial i_t}{\partial W_{hf}}+\frac{\partial L}{\partial g_t} \frac{\partial g_t}{\partial W_{hf}} \notag \end{align} LtLhtLotLctLftLitLgtLWhftL=LtLhtLt+ot+1Lhtot+1+ft+1Lhtft+1+it+1Lhtit+1+gt+1Lhtgt+1=htLotht=htLctht+ct+1Lctct+1=ctLftct=ctLitct=ctLgtct=ftLWhfft+itLWhfit+gtLWhfgt
    根据式 ( 2.1 ) ∼ ( 2.7 ) (2.1)\sim (2.7) (2.1)(2.7) 将上面各式计算出来。 t   ( t = 1 , 2 , … , T − 1 ) t\space (t=1,2,\dots,T-1) t (t=1,2,,T1) 时刻各个变量的偏导总结如下(部分等式由于展开过长而不代入展开)。
    ∂ L ∂ L t = 1 ∂ L ∂ h t = ∂ L t ∂ h t + ∂ L t + 1 ∂ h t + 1 t a n h ( c t + 1 ) σ ′ ( ⋅ ) W h o + ∂ L t + 1 ∂ h t + 1 o t + 1 t a n h ′ ( ⋅ ) c t σ ′ ( ⋅ ) W h f + ∂ L t + 1 ∂ h t + 1 o t + 1 t a n h ′ ( ⋅ ) g t + 1 σ ′ ( ⋅ ) W h i + ∂ L t + 1 ∂ h t + 1 o t + 1 t a n h ′ ( ⋅ ) i t + 1 σ ′ ( ⋅ ) W h g ∂ L ∂ o t = ∂ L ∂ h t t a n h ( c t ) ∂ L ∂ c t = ∂ L ∂ h t o t t a n h ′ ( ⋅ ) + ∂ L ∂ c t + 1 f t + 1 ∂ L ∂ f t = ∂ L ∂ c t c t − 1 ∂ L ∂ i t = ∂ L ∂ c t g t ∂ L ∂ g t = ∂ L ∂ c t i t ∂ L ∂ W h f ⟨ t ⟩ = ∂ L ∂ f t ∂ f t ∂ W h f = ∂ L ∂ f t h t − 1 = ∂ L ∂ c t c t − 1 σ ′ ( ⋅ ) h t − 1 \begin{align} \frac{\partial L}{\partial L_t} &=1\notag \\\notag \\ \frac{\partial L}{\partial h_t} &= \frac{\partial L_t}{\partial h_t} + \frac{\partial L_{t+1}}{\partial h_{t+1}}{\rm tanh}(c_{t+1})\sigma'(·)W_{ho} +\frac{\partial L_{t+1}}{\partial h_{t+1}} o_{t+1}{\rm tanh'(·)}c_t\sigma'(·)W_{hf} + \frac{\partial L_{t+1}}{\partial h_{t+1}}o_{t+1}{\rm tanh'(·)}g_{t+1}\sigma'(·)W_{hi}+\frac{\partial L_{t+1}}{\partial h_{t+1}}o_{t+1}{\rm tanh'(·)}i_{t+1}\sigma'(·)W_{hg} \notag\\\notag \\ \frac{\partial L}{\partial o_t} &= \frac{\partial L}{\partial h_t} {\rm tanh} (c_t) \notag \\\notag \\ \frac{\partial L}{\partial c_t} &= \frac{\partial L}{\partial h_t} o_t{\rm tanh'(·)} + \frac{\partial L}{\partial c_{t+1}} f_{t+1} \tag{*} \\\notag \\ \frac{\partial L}{\partial f_t}&=\frac{\partial L}{\partial c_t}c_{t-1} \notag\\\notag \\ \frac{\partial L}{\partial i_t}&=\frac{\partial L}{\partial c_t}g_t \notag\\\notag \\ \frac{\partial L}{\partial g_t}&=\frac{\partial L}{\partial c_t}i_t \notag\\\notag \\ \frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}} &=\frac{\partial L}{\partial f_t} \frac{\partial f_t}{\partial W_{hf}}=\frac{\partial L}{\partial f_t} h_{t-1}=\frac{\partial L}{\partial c_t}c_{t-1}\sigma'(·) h_{t-1} \tag{**} \end{align} LtLhtLotLctLftLitLgtLWhftL=1=htLt+ht+1Lt+1tanh(ct+1)σ()Who+ht+1Lt+1ot+1tanh()ctσ()Whf+ht+1Lt+1ot+1tanh()gt+1σ()Whi+ht+1Lt+1ot+1tanh()it+1σ()Whg=htLtanh(ct)=htLottanh()+ct+1Lft+1=ctLct1=ctLgt=ctLit=ftLWhfft=ftLht1=ctLct1σ()ht1(*)(**)
    上面式 ( ∗ ) (*) () 没有计算出 ∂ L ∂ c t \frac{\partial L}{\partial c_t} ctL 的通项公式,只是给出了递推公式,对其归纳后得
    ∂ L ∂ c t = ∑ t = 1 T ∂ L ∂ h i o i t a n h ′ ( c i ) ( 1 + ∏ j = 2 i f j ) \frac{\partial L}{\partial c_t}=\sum_{t=1}^T\frac{\partial L}{\partial h_i}o_i{\rm tanh'}(c_i)\left( 1+\prod_{j=2}^i f_j\right) ctL=t=1ThiLoitanh(ci)(1+j=2ifj)
    进而计算出式 ( ∗ ∗ ) (**) ()
    ∂ L ∂ W h f ⟨ t ⟩ = c t − 1 σ ′ ( W x f x t + W h f h t − 1 + b f ) h t − 1 ∑ t = 1 T ∂ L ∂ h i o i t a n h ′ ( c i ) ( 1 + ∏ j = 2 i f j ) \frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}}= c_{t-1}\sigma'(W_{xf}x_t+W_{hf}h_{t-1}+b_f)h_{t-1}\sum_{t=1}^T\frac{\partial L}{\partial h_i}o_i{\rm tanh'}(c_i)\left( 1+\prod_{j=2}^i f_j\right) WhftL=ct1σ(Wxfxt+Whfht1+bf)ht1t=1ThiLoitanh(ci)(1+j=2ifj)
    最后将全部的梯度贡献值相加,得
    ∂ L ∂ W h f = ∂ L ∂ h T o T t a n h ′ ( c T ) c T − 1 σ ′ ( W x f x T + W h f h T − 1 + b f ) h T − 1 + ∑ t = 1 T − 1 c t − 1 σ ′ ( W x f x t + W h f h t − 1 + b f ) h t − 1 ∑ t = 1 T ∂ L ∂ h i o i t a n h ′ ( c i ) ( 1 + ∏ j = 2 i f j ) \frac{\partial L}{\partial W_{hf}} = \frac{\partial L}{\partial h_T}o_T {\rm tanh'}(c_T)c_{T-1}\sigma'(W_{xf}x_T+W_{hf}h_{T-1}+b_f) h_{T-1} + \sum_{t=1}^{T-1} c_{t-1}\sigma'(W_{xf}x_t+W_{hf}h_{t-1}+b_f)h_{t-1}\sum_{t=1}^T\frac{\partial L}{\partial h_i}o_i{\rm tanh'}(c_i)\left( 1+\prod_{j=2}^i f_j\right) WhfL=hTLoTtanh(cT)cT1σ(WxfxT+WhfhT1+bf)hT1+t=1T1ct1σ(Wxfxt+Whfht1+bf)ht1t=1ThiLoitanh(ci)(1+j=2ifj)
    也可以不体现函数的参数,得到更简洁的形式
    ∂ L ∂ W h f = ∂ L ∂ h T o T t a n h ′ ( ⋅ ) c T − 1 σ ′ ( ⋅ ) h T − 1 + ∑ t = 1 T − 1 c t − 1 σ ′ ( ⋅ ) h t − 1 ∑ t = 1 T ∂ L ∂ h i o i t a n h ′ ( ⋅ ) ( 1 + ∏ j = 2 i f j ) \frac{\partial L}{\partial W_{hf}} = \frac{\partial L}{\partial h_T}o_T {\rm tanh'}(·)c_{T-1}\sigma'(·) h_{T-1} + \sum_{t=1}^{T-1} c_{t-1}\sigma'(·)h_{t-1}\sum_{t=1}^T\frac{\partial L}{\partial h_i}o_i{\rm tanh'}(·)\left( 1+\prod_{j=2}^i f_j\right) WhfL=hTLoTtanh()cT1σ()hT1+t=1T1ct1σ()ht1t=1ThiLoitanh()(1+j=2ifj)

    由于无法将 T T T 时刻的梯度贡献值与其他时刻的梯度贡献值统一表示,因此,对应上式中加号左右的两部分。

    上面计算出了 ∂ L ∂ W h f \frac{\partial L}{\partial W_{hf}} WhfL ,类似地也可以计算出 L L L W x f W_{xf} Wxf W h i W_{hi} Whi W x i W_{xi} Wxi W h g W_{hg} Whg W x g W_{xg} Wxg W h o W_{ho} Who W x o W_{xo} Wxo b f b_f bf b i b_i bi b g b_g bg b o b_o bo

    以下在讨论引入 L t L_t Lt 的原因,选读。

    不同于 RNN 反向传播公式的推导,RNN 并没有特意地引入 L t L_t Lt ,而 LSTM 反向传播公式的推导中却需要引入。我们不妨先不引入该符号,当计算 ∂ L ∂ h t   ( t = 1 , 2 , … , T − 1 ) \frac{\partial L}{\partial h_t}\space (t=1,2,\dots,T-1) htL (t=1,2,,T1) 时,我们可以找到两条从 L L L h t h_t ht 的路径,分别是 L → h t L\rightarrow h_t Lht L → h t + 1 → o t + 1 → h t L\rightarrow h_{t+1}\rightarrow o_{t+1}\rightarrow h_t Lht+1ot+1ht ,因此 ∂ L ∂ h t \frac{\partial L}{\partial h_t} htL 可以表示为 ∂ L ∂ h t = ∂ L ∂ h t + ∂ L ∂ h t + 1 ∂ h t + 1 ∂ o t + 1 ∂ o t + 1 ∂ h t \frac{\partial L}{\partial h_t}=\frac{\partial L}{\partial h_t}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial o_{t+1}}\frac{\partial o_{t+1}}{\partial h_t} htL=htL+ht+1Lot+1ht+1htot+1 ,观察等式两边会发现,这显然不合理。

    出现这种情况的原因很好理解。 ∂ L ∂ h t \frac{\partial L}{\partial h_t} htL 只是一个符号,表示全部的从 L L L h t h_t ht 的路径(直接到达或经过其他任意结点中转到达)对应的链式求导之和; ∂ o t + 1 ∂ h t \frac{\partial o_{t+1}}{\partial h_t} htot+1 也只是符号,表达全部的从 o t + 1 o_{t+1} ot+1 h t h_t ht 的路径对应的链式求导之和,不过由于只存在一条路径,这使得 ∂ o t + 1 ∂ h t \frac{\partial o_{t+1}}{\partial h_t} htot+1 能够唯一地代表一条路径,所以我们也就不需要继续将 ∂ o t + 1 ∂ h t \frac{\partial o_{t+1}}{\partial h_t} htot+1 化为偏导连乘的形式了;类似的道理, ∂ L ∂ h t + 1 ∂ h t + 1 ∂ o t + 1 \frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial o_{t+1}} ht+1Lot+1ht+1 可以由 ∂ L ∂ o t + 1 \frac{\partial L}{\partial o_{t+1}} ot+1L 代替,即 ∂ L ∂ o t + 1 = ∂ L ∂ h t + 1 ∂ h t + 1 ∂ o t + 1 \frac{\partial L}{\partial o_{t+1}}=\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial o_{t+1}} ot+1L=ht+1Lot+1ht+1,这正是因为从 L L L o t + 1 o_{t+1} ot+1 的路径唯一。综上,只有路径唯一时才能用符号 ∂ ∂ \frac{\partial}{\partial} 表示完整的链式偏导。

    重新考虑不引入符号 L t L_t Lt 出现的问题,等式左侧的符号 ∂ L ∂ h t \frac{\partial L}{\partial h_t} htL 对应了多条从 L L L h t h_t ht 的路径,等式右侧需要详细地将每条路径对应的链式偏导表达出来。如果想要唯一地表达路径 L → h t L\rightarrow h_t Lht (直接到达)则必须要引入另一个中间结点 L t L_t Lt ,从而构成新的路径 L → L t → h t L\rightarrow L_t\rightarrow h_t LLtht,对应的链式偏导为 ∂ L ∂ L t ∂ L t ∂ h t \frac{\partial L}{\partial L_t}\frac{\partial L_t}{\partial h_t} LtLhtLt

    形象地理解一下,我从家走到学校告诉同学:“放学的时候小心从我家到学校路上的狗”,同学傻了“那么多道,我怎么知道是哪条有狗啊!”,我细说“从我家先到布达拉宫,再到天安门,再到曹县,最后到学校的那条路上有狗;还有,从我家直通学校的路上也有,你可要小心啊!”,同学一听既害怕又感激,于是决定坐飞机回家。

    从这个例子中可以看出 ⌈ \lceil “家 → ⋯ → \rightarrow \dots\rightarrow 学校”有狗 ⌋ ⇔ ⌈ \rfloor\Leftrightarrow \lceil “家 → \rightarrow 布达拉宫 → \rightarrow 天安门 → \rightarrow 曹县 → \rightarrow 学校”有狗,并且“家 → \rightarrow 学校”有狗 ⌋ \rfloor ,对应于等式的左侧和等式的右侧。

    缓解所谓的“梯度消失”

    k i = ∂ L ∂ h i k_i=\frac{\partial L}{\partial h_i} ki=hiL ,将处理后的式 ( ∗ ∗ ) (**) () 展开,得
    ∂ L ∂ W h f ⟨ t ⟩ = c t − 1 σ ′ ( ⋅ ) h t − 1 [ ( k 1 o 1 ) + ( k 2 o 2 f 2 ) + ( k 3 o 3 f 3 f 2 ) + ⋯ + ( k T o T f T … f 3 f 2 ) ] \frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}}= c_{t-1}\sigma'(·)h_{t-1} \left[ (k_1o_1)+(k_2o_2f_2) + (k_3o_3f_3f_2) + \dots + (k_To_Tf_T\dots f_3f_2) \right] WhftL=ct1σ()ht1[(k1o1)+(k2o2f2)+(k3o3f3f2)++(kToTfTf3f2)]
    其中, f i f_i fi 为 sigmoid 函数,通过监督训练,这些函数的取值将起到“门”的作用,即非 0 0 0 1 1 1 。上式中显然不存在激活函数导数连乘的形式,这降低了梯度消失发生的可能,另外还通过多个 sigmoid 函数连乘实现对远距离的信息进行筛选,弥补了 RNN 无法解决长期依赖的问题。

2.4. 训练过程

根据上面的动态传播过程图我们知道,每个 LSTM 单元的四个神经网络(结构图中的黄色部件)的输入都是向量 h t − 1 h_{t-1} ht1 x t x_t xt 经过拼接(concatenate)后的向量,输出到下一个单元的向量为 h t h_t ht,当然,这里无需考虑 c t − 1 c_{t-1} ct1,因为 c t − 1 c_{t-1} ct1 不经过神经网络,也就不存在维度变化。假设 h t − 1 h_{t-1} ht1 h i d d e n _ s i z e \rm hidden\_size hidden_size 维向量, x t x_t xt x _ s i z e \rm x\_size x_size 维向量,每个神经网络的输出均为 h i d d e n _ s i z e \rm hidden\_size hidden_size 维向量,相当于将 h i d d e n _ s i z e + x _ s i z e \rm hidden\_size+x\_size hidden_size+x_size 维向量映射到 h i d d e n _ s i z e \rm hidden\_size hidden_size 维向量,所以每个神经网络对应的参数可以表示为 ( h i d d e n _ s i z e + x _ s i z e , h i d d e n _ s i z e ) (\rm hidden\_size+x\_size,\rm hidden\_size) (hidden_size+x_size,hidden_size) 的矩阵。四个神经网络,将 h i d d e n _ s i z e + x _ s i z e \rm hidden\_size+x\_size hidden_size+x_size 维向量映射到 4 × h i d d e n _ s i z e \rm 4\times hidden\_size 4×hidden_size 维向量,一个 LSTM 单元完整的参数矩阵为 ( h i d d e n _ s i z e + x _ s i z e , 4 × h i d d e n _ s i z e ) (\rm hidden\_size+x\_size,4\times hidden\_size) (hidden_size+x_size,4×hidden_size)。由于 LSTM 每个单元共享参数矩阵,所以整个 LSTM 的参数矩阵即为 ( h i d d e n _ s i z e + x _ s i z e , 4 × h i d d e n _ s i z e ) (\rm hidden\_size+x\_size,4\times hidden\_size) (hidden_size+x_size,4×hidden_size)。注意,将 4 4 4 个神经网络对应的参数矩阵合并只是为了进行矩阵乘法时更简便,所以计算完之后还是要拆开,再进行不同的运算。

举个简单的例子,训练 b a t c h _ s i z e = 64 \rm batch\_size=64 batch_size=64 的一组语句,每个语句 20 20 20 个词,每个词向量 200 200 200 维,隐藏层向量 h t h_t ht 128 128 128 维, c t c_t ct h t h_t ht 同维。LSTM 的输入张量为 ( 64 , 20 , 200 ) (64, 20, 200) (64,20,200),LSTM 的参数矩阵为 ( 128 + 200 , 4 × 128 ) (128+200,4\times 128) (128+200,4×128)。对于某一个 LSTM 单元来说,输入为 ( 64 , 200 ) (64, 200) (64,200) 的矩阵,和 h t h_t ht 拼接得到 ( 64 , 200 + 128 ) (64, 200+128) (64,200+128),输入矩阵与参数矩阵相乘得到 ( 64 , 4 × 128 ) (64,4\times 128) (64,4×128),即每个神经网络的输出为 ( 64 , 128 ) (64, 128) (64,128)。神经网络的输出会进行一些不影响矩阵维度的位操作,所以该单元输出的 c t c_t ct h t h_t ht 仍然为 ( 64 , 128 ) (64,128) (64,128) 的矩阵。每个单元都重复进行相同的操作, 20 20 20 次操作(时间步)后,最终全部单元的输出为 ( 20 , 64 , 128 ) (20,64, 128) (20,64,128) 的矩阵。

如此我们得到了 LSTM 的输出矩阵为 ( t i m e _ s t e p , b a t c h _ s i z e , h i d d e n _ s i z e ) \rm(time\_step, batch\_size, hidden\_size) (time_step,batch_size,hidden_size)。根据下游任务的不同,会定义不同的损失函数,比如分类任务,那么我们仅会选择最后一个时刻的这批样本的交叉熵作为损失函数;当然,对于其他的一些任务,也可以选择对全部时刻的交叉熵进行加和或求均值作为最终的损失函数。

这里我们仅讲解将 LSTM 最后一个单元(时刻)输出结果的交叉熵作为损失函数,其他情况类似。假设全部单词数为 v o c a b u l a r y _ s i z e \rm vocabulary\_size vocabulary_size,我们需要先定义一个可训练的矩阵,大小为 ( h i d d e n _ s i z e , v o c a b u l a r y _ s i z e ) \rm (hidden\_size, vocabulary\_size) (hidden_size,vocabulary_size),作用是将 LSTM 最后一个单元的输出为 ( b a t c h _ s i z e , h i d d e n _ s i z e ) \rm (batch\_size, hidden\_size) (batch_size,hidden_size) 的矩阵映射到大小为 ( b a t c h _ s i z e , v o c a b u l a r y _ s i z e ) \rm (batch\_size, vocabulary\_size) (batch_size,vocabulary_size) 的矩阵上。这样,矩阵的每一行代表一个样本(单词),按行 softmax 后每行均为概率分布。每个样本根据对应的独热“标签”计算对应的交叉熵,再将 b a t c h _ s i z e \rm batch\_size batch_size 个交叉熵加和或者求均值作为目标函数。采用梯度下降等方法对模型参数进行更新。

注意区别 softmax 和交叉熵。softmax 只是一种将一般向量化为同维概率分布的手法,而交叉熵则是一种将两组概率分布变为标量的计算。

LSTM 作为语言模型,任务是根据输入的若干个单词预测下一个单词。因此,每个 LSTM 的“标签”是该条输入语句当前单词的下一个单词对应的独热编码。对于单词处于语句末尾的情况,一般会在句末引入特殊的语句结束符号;还有一些其他的与具体实现有关的特殊情况,在这里不详细展开。

REF

[1] Understanding LSTM Networks - colah’s blog

[2] LSTM神经网络详解 - CSDN博客

[3] 详解LSTM - 知乎 - 仅参考图片

[4] 《神经网络的梯度推导与代码验证》之LSTM的前向传播和反向梯度推导 - 博客园

[5] 4.RNN梯度消失回顾(公式推导)- bilibili

[6] LSTM如何来避免梯度弥散和梯度爆炸? - 知乎 - 用户Quokka的回答

[7] LSTM如何解决RNN带来的梯度消失问题 - CSDN博客

[8] LSTM训练过程与参数解读 - CSDN

[9] 使用LSTM实现语言模型 - 知乎

[10] 关于LSTM的输入和训练过程的理解 - 博客园

[11] tf.nn.dynamic_rnn的输出outputs和state含义 - CSDN

[12] tf.nn.softmax_cross_entropy_with_logits函数 - CSDN

[13] LSTM每一个时间步都有一个损失函数吗? - 知乎

你可能感兴趣的:(【机器学习】,lstm,人工智能,rnn,自然语言处理)