在DNN和CNN中,训练样本的输入和输出往往都是确定的,并且对单个样本前后之间的关系不关心。这就导致DNN和CNN不好解决训练样本输入是连续的序列,且序列的长短不一,比如基于时间的序列:一段段连续的语音,一段段连续的文字。这些序列比较长,且长度不一,比较难拆分成一个个独立的样本来通过DNN/CNN进行训练,并且序列前后之间往往有很大的关系。而这正是RNN比较擅长的任务。先晒一张大家经常看到的图。
如果现在看不懂上图,没关系,现在只需要有一个大致概念,下文还会深入讲解。图1是一张动态RNN(关于动态和静态先不要关心,下文会再讲解)的结构图,因为动态RNN使用的比较多,所以看到的RNN结构图几乎都是动态的。可以看到上图有一个输入序列:Xt-1、Xt、Xt+1。这里的X常常是一个一维的张量,比如一张图片的一行像素值、一条语音的一个采样值等。Xt 输入以后首先与U计算(矩阵乘法)得到一个值,前一个序列的状态值St-1与W计算(矩阵乘法)得到另一个值,两个值组合后再与不同的权重值计算分别得到本次的O和S。O即是输出,S作为下一个序列的状态输入,从而保证不断地把前面的信息传递给后面,RNN相对于CNN最特殊的地方就在于这个状态S。上图可以看到三个U和W都是一样的,在动态RNN中,所有单元的U确实是一样的,共用权重,所以经常可以看到图1左边的表示方法。例如,一张 10*20 的灰度图片,假设每一行是一个输入序列,即每个序列长都是20,总共10个序列,则这里的X就是一个长度为10的张量,X1~X10都会共用同一套权重计算。这里还有一个问题,就是从第二个输入序列开始都有前一个计算的状态S,那么第一个的S是哪里来的呢?常用的做法是默认为0,当然也支持在定义网络的时候定义一个初始值。
如果这里你已经理解了RNN的大致原理,那么看下面双向RNN的结构图就应该很容易理解了。在双向RNN中,不仅仅有一个从前往后的状态传递,而且还有一个从后往前的状态传递,在实际应用中,也比较容易理解,例如一句话:我现在正在()吃饭,对于前面的空缺往往可以从后面的信息获取,这里空缺的部分是“餐厅”之类的词语可能性更大。
在RNN中还存在另一种结构----多层RNN,和多个卷积层一样,前一个RNN的输出作为下一层RNN的输入,这里使用h表示上面的状态s,注意这里多层RNN之间并不共享权重。
RNN虽然解决了序列之间的依赖问题,但也仅限于简单的逻辑和样本。对于复杂的问题,激活函数的损失值在传递的过程中,不仅要在层与层之间传播,而且在每一层的样本序列之间也要传播,这就导致随着层数的增加,损失值的传递会越来越弱,所以RNN无法学习太长的序列特征。于是神经网络又演化了许多RNN的变体版本,LSTM正式其中之一。
LSTM(长短记忆的时间递归神经网络Long Short Term Memory)是一种特殊的RNN,它可以学习长期依赖信息。先放一张LSTM的结构图。
图4中每条线表示数据传输,圆圈代表计算操作,x表示矩阵点乘,+表示加法。合在一起的线表示向量的拼接(concat),分开的线表示内容被复制,橘黄色框中表示不同的激活函数。相比RNN只有一个传递状态不同,LSTM有两个传输状态(cell state)和(hidden state),相当于状态。图4中Xt输入后,首先与拼接(concat)为新的向量,乘以不同的权重后再通过不同的激活函数输出。
LSTM的核心思想引入了一个叫细胞状态的连接,这个细胞状态用来存放想要记忆的东西,同时在里面加入了三个门(三个sigmoid激活函数)。
看到这里应该对LSTM的整个结构和计算流程非常清楚了。关于参数个数和输出格式会在下文实践阶段详细介绍,本节只要理解整个数据流动过程即可。
在tensorflow中使用过RNN的同学应该知道,定义好RNN单元cell之后,还需要将它们连接起来构成RNN网络。Tensorflow提供了构建静态RNN、动态RNN以及双向RNN的API(更多API请参考TF官网)。
在单层、多层和双向RNN中都有动态、静态之分。静态的意思就是按照样本的时间序列个数(n)展开,在图中创建n个序列的cell;动态的意思就是只创建样本中一个序列的RNN,其他序列数据都会循环来进入RNN。参考图1,动态RNN只会创建一个cell,共享权重,静态RNN会创建3个cell,并且要求输入序列的个数必须是3。这就导致通过静态生成的RNN网络,生成时间更长、内存占用更多、导出的模型更大、使用非常不便,而通过动态生成的RNN网络,占用内存少、模型体积小,还能支持不同的序列个数,因此使用时基本都是动态RNN网络。
(1)动态RNN
import tensorflow as tf
import numpy as np
X = np.random.randn(2,4,5)
X[1,1:] = 0
seq_length = [4,1]
cell = tf.contrib.rnn.BasicLSTMCell(num_units=3, state_is_tuple=True)
outpout, states = tf.nn.dynamic_rnn(cell,X,seq_length,dtype=tf.float64)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
result, sta = sess.run([outpout,states])
print(result.shape)
print(result[0])
print(result[1])
print("--------------------------------------------------")
print(type(sta))
print(sta[0])
print(sta[1])
上述程序简单介绍了tf.nn.dynamic_rnn的使用,在tf.contrib.rnn.BasicLSTMCell中设置state_is_tuple=True,tf.nn.dynamic_rnn的返回就会把图4中的和放在一个tuple中输出,否则就会放在一个张量中。换种说法,如果state_is_tuple=True,则上述程序的sta类型是 tuple,如果为False,则sta的类型就是np.array。上述程序的输出结果如下:
(2, 4, 3)
[[-0.05057971 0.04302384 0.10817957]
[ 0.07582827 -0.24784242 0.06353555]
[ 0.02994161 -0.03758307 0.28416459]
[-0.0544189 -0.12191656 0.04009293]]
[[-0.05939422 0.07873275 0.08203859]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]]
--------------------------------------------------
[[-0.14131451 -0.17003618 0.15147233]
[-0.2507147 0.22557675 0.18003317]]
[[-0.0544189 -0.12191656 0.04009293]
[-0.05939422 0.07873275 0.08203859]]
从上面结果可以看出result的shape是(2,4,3),2表示batch_size;4表示序列个数,因为输入X的第二个维度就是4;3是因为程序中设置num_units=3。这里解释下num_units这个参数,num_units表示2中介绍的全连接W权重的第二个维度,上面X输入的序列长度是5,num_units=3表示输出是一个长度为3的一维向量。因此的结果就是长度为8的一维向量,则W的大小就是 8*3,四个橘黄色小框的参数个数就是 (8*3+3)*4。
再解释下tf.nn.dynamic_rnn的输出,第一个元素是RNN网络的输出output,也即图4中的集合,因为指定了第二个样本的序列个数是1,所以第二个的输出除第一行外都补0;第二个元素是状态states,states是一个tuple,不仅仅输出了,还输出了,第一个元素是,第二个元素是,可以发现sta[1]的值和result中的最后输出值是一样的。
(2)双向多层动态RNN
def lstm_layer(num_units, inputs):
stacked_fw_rnn = []
stacked_bw_rnn = []
for i in range(2):
stacked_fw_rnn.append(tf.contrib.rnn.LSTMCell(num_units=num_units))
stacked_bw_rnn.append(tf.contrib.rnn.LSTMCell(num_units=num_units))
mcellf = tf.contrib.rnn.MultiRNNCell(stacked_fw_rnn)
mcellb = tf.contrib.rnn.MultiRNNCell(stacked_bw_rnn)
output, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn([mcellf],[mcellb],inputs,dtype=tf.float32)
return output
上面是自己以前写的的一个双向动态RNN函数,可以把range的参数提取到函数参数中,设置层数。比较简单,就不过多介绍了,注意返回值是一个元组,第一个元素是输出,第二个是正向状态值,第三个是反向状态值。
参考资料
https://zhuanlan.zhihu.com/p/32085405
https://www.zhihu.com/question/41949741
https://blog.csdn.net/zhaojc1995/article/details/80572098