由浅入深理解 RNN

本篇文章由浅入深地介绍了RNN的模型,适合有一定机器学习基础,想由浅入深地理解RNN的同学。
本文目前写了主要框架,遇到不理解的地方,建议打开参考文献进行扩展阅读。

对于自然语言学习来说,一种典型的模型是:给定前面的一系列单词,预测接下来最有可能的单词是什么。
比如 :

我 昨天 上学 迟到 了 ,老师 批评 了 ____。

传统的NLP使用N-gram模型来预测,前面N个词影响当前位置的预测结果(此案例中 要向前包含到“我”这个单词,才能推出此处的结果是 “我”),但是如果我们想处理任意长度的句子,N设为多少都不合适;另外,模型的大小和N的关系是指数级的,4-Gram模型就会占用海量的存储空间。
RNN解决的就是此类问题,因为RNN理论上可以往前看(往后看)任意多个词。【1】

一、RNN模型的理解

首先从全连接模型开始

如下所示是一组典型的全连接网络,每一个全连接网络由输入层Xt,隐藏层St和输出层Ot组成

每个单独的全连接网络,代表相应的时刻的输入(在NLP中就是相应位置的单词)和相应输出之间的关系,但是这个全连接网络组是没有考虑到“历史时刻”的信息。

由浅入深理解 RNN_第1张图片
图1

所以RNN网络,将前一时刻的隐藏层,一起作为当前时刻隐藏层的输入,如图2所示。

前一时刻的隐藏层S(t-1),经过一个参数矩阵W相乘作为S(t)的输入,这样逐步递推,所以任意时刻的输入参数都包含了所有的历史信息。

由浅入深理解 RNN_第2张图片
图2

在NLP中,每个输入Xt代表一个字符(单词),Ot代表相应的推断结果。

在实际实现的时候,不会重复写这么多网络,会用一个网络的循环来模拟上述结构,即为Recurrent Nerual Network。这是因为在RNNs中,每输入一步,每一层,各自都 共享参数。这主要是说明RNN中的每一步都在做相同的事,只是输入不同,因此大大地降低了网络中需要学习的参数【4】。

同时,RNN中每一步都会有输出,但是每一步的输出并不是必须的。比如,我们需要预测一个步态序列属于哪一个人,我们仅仅需要关心最后一个步态输入后的输出,而不需要知道每个步态输入后的输出。同理,每步的输入也不是必须的【4】。所以有以下各种的RNN网络【2】

由浅入深理解 RNN_第3张图片
图3

上述模型的解释为:(1)普通的神经网络(2)序列输出(图片->文字解释)(3)序列输入(文本分类)(4)序列输入与序列输出:机器翻译(5)同步序列输入与输出:视频分类中每一帧的解释

以“预测下一个字符”作为案例,字符级别的RNN预测网络如下【2】

由浅入深理解 RNN_第4张图片
图4

二、RNN的推导

基本的RNN的推导,基于神经网络的BP算法,称之为BPTT,即沿着每一个时间点倒退,如下图所示【5】。纯粹的BPTT在实际中并不实用,当t很大的时候,就出现了梯度消失的现象,所以后面重点介绍几个常用的RNN变形及其推导方式。

由浅入深理解 RNN_第5张图片
图5

三、RNN的实用化变形:LSTM,GRU

LSTM和GRU是常见的两种RNN cell,可以避免原始的RNN的长期依赖问题

主要的思想是训练是否忘记“历史时刻”的输入,下图分别表示LSTM和GRU【6】

深刻理解LSTM: http://colah.github.io/posts/2015-08-Understanding-LSTMs/

深刻理解GRU:http://blog.csdn.net/meanme/article/details/48845793

总体来说 LSTM和GRU的训练效果都差不多,GRU往往更快一些

由浅入深理解 RNN_第6张图片
图6

参考文献

【1】http://blog.csdn.net/qq_23225317/article/details/77834890

【2】http://karpathy.github.io/2015/05/21/rnn-effectiveness/

【3】https://colah.github.io/posts/2015-08-Understanding-LSTMs/

【4】https://www.zybuluo.com/Duanxx/note/545194

【5】http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/

【6】http://blog.csdn.net/meanme/article/details/48845793

你可能感兴趣的:(由浅入深理解 RNN)