机器学习面试题-为啥LSTM比RNN好

这里写自定义目录标题

  • 问题引入
    • 问题回答

问题引入

其实这算是个经典的问题了,在一般的只要你做过时间序列相关的项目或者理论的时候,LSTM和RNN的对比肯定是要问的。那两者到底有啥区别呢?

问题回答

其实对于这个问题,要从RNN在发展过程中带来的令人诟病的短处说起,RNN在train参数的时候,需要反向传播梯度,这个梯度是这么算的:
w i + 1 = w i − r ⋅ ∂ L o s s ∂ w ∣ w : w i , r > 0 w^{i+1}=w^{i}-r\cdot\frac{\partial{Loss }}{\partial{w}}|_{w:w^{i}},r>0 wi+1=wirwLossw:wi,r>0
其中 r r r是学习率, ∂ L o s s ∂ w ∣ w : w i \frac{\partial{Loss }}{\partial{w}}|_{w:w^{i}} wLossw:wi是损失函数在w处的导数,针对RNN在结构上很深的特征,会产生梯度消失和梯度爆炸,其中需要了解下什么是梯度消失和梯度爆炸,梯度消失指的是,RNN在某些 w i w^i wi取值上,导致梯度很小,梯度爆炸指的是, w i w^i wi在某些取值上,导致梯度特别大。如果你的学习率 r r r不变的话,那么参数要么几乎不变,要么就是变化剧烈,到时迭代动荡很难手收敛。通过我们对RNN的网络结构的建模,我们发现他的梯度是这个样子的:
∂ L t ∂ W h = ∑ t = 0 T ∑ k = 0 t ∂ L t ∂ y t ∂ y t ∂ h t ( ∏ j = k + 1 t ∂ h j ∂ h j − 1 ) ∂ h k ∂ W h \frac{\partial{L_{t}}}{\partial{W^{h}}}=\sum_{t=0}^{T}{\sum_{k=0}^{t}{ \frac{\partial{L_{t}}}{\partial{y_t}} \frac{\partial{y_{t}}}{\partial{h_t}} (\prod_{j=k+1}^{t} \frac{\partial{h_{j}}}{\partial{h_{j-1}}} ) \frac{\partial{h_{k}}}{\partial{W^h}}}} WhLt=t=0Tk=0tytLthtyt(j=k+1thj1hj)Whhk
我们先不管这一大串公式是啥意思,大值得意思就是上面公式里面有依赖于时间 t t t的连乘符号;修正 t t t时刻的误差需要考虑之前的所有时间 k k k的隐藏层对时间 t t t的影响,当 k k k t t t距离越远,对应着隐含层之间的连乘次数就越多。就是这个连乘的结构产生了梯度消失,梯度爆炸也是它导致的。具体大一大波公式有需要看的话可以看下参考中的地(我只是搬运工)。
而LSTM(长短时记忆网络),因为可以通过阀门(gate,其实就是概率,共有输出、遗忘、输入三个阀门)记忆一些长期信息,所以,相比RNN,保留了更多长期信息(相应地也就保留了更多的梯度)。隐层之间的输入输出可以表示为:
c j = σ ( W f X j + b f ) c j − 1 + σ ( W i X j + b i ) σ ( W X j + b ) c_{j}=\sigma(W^fX_{j}+b^f)c_{j-1}+\sigma({W^iX_{j}}+b^i)\sigma(WX_{j}+b) cj=σ(WfXj+bf)cj1+σ(WiXj+bi)σ(WXj+b),于是连乘的项可以表示为:
∂ c j ∂ c j − 1 = σ ( W f X j + b ) \frac{\partial{c_{j}}}{\partial{c_{j-1}}}=\sigma(W^fX_{j}+b) cj1cj=σ(WfXj+b)
该值得范围在0-1之间,在参数更新的过程中,可以通过控制bais较大来控制梯度保持在1,即使通过多次的连乘操作,梯度也不会下降到消失的状态。所以,相比RNN,在LSTM上,梯度消失问题得到了一定程度的缓解。

更多内容,查看如下(百面机器学习):
机器学习面试题-为啥LSTM比RNN好_第1张图片

https://www.zhihu.com/question/44895610/answer/616818627
https://zhuanlan.zhihu.com/p/30844905
https://blog.csdn.net/laolu1573/article/details/77470889

你可能感兴趣的:(算法题,机器学习,深度学习)