tensorflow中对lstm及双向lstm的理解

双向RNN(LSTM)的实现参考:
https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/bidirectional_rnn.py
需要注意的是,里面的static_bidirectional_rnn()函数是来自tf.contrib.rnn的,它的输入必须是list类型,而最新的官方接口可以直接使用tensor输入


关于基本的LSTM

构建一个rnn需要有两个关键东西,
cell,就是LSTM里面的一个模块;
网络,tf.nn.dynamic_rnn()、tf.nn.static_bidirectional_rnn()等


上面两个网络最后一般返回两个变量,一个是outputs,一个是state
1.
state是一个tuple(默认情况下),内容是(c,h),看LSTM的公式就知道,c就是细胞状态,h就是当前的输出
所以假设输入是[batch_size,steps,dim],中间4个门的神经元个数都是m,
那么c和h的shape都是[batch_size,dim,m],即没有steps这个维度,因为state是当前的最新状态,即经过steps步计算后的最终的状态
2.
outputs的shape是[batch_size,steps,dim,m]
outputs也是最终的输出,它包含了所有steps(视频的话理解为所有帧)的最终的输出

结论:
所以outputs[:,-1,:,:]的数值和state里h的数值是一样的
因为h就是最后step(最后帧)的最终输出


关于双向LSTM

原理可以参考:双向长短时记忆循环神经网络详解(Bi-directional LSTM RNN)

相关文献:《Bidirectional Recurrent Neural Networks》

在tensorflow的官方接口中,相关接口是tf.nn.static_bidirectional_rnn()
需要定义两个cell,分别用于序列的前向传播和后向传播
与一般的LSTM不同,该函数最后的输出有三个变量,outputs,state_fw,state_bw,每个变量的含义差不多和上面所述的一样
不一样的是,state_fw表示前向传播的最终状态,state_bw表示后向传播的最终状态
更要注意的是outputs,它表示的是所有steps(所有帧)的最终输出,但是这个输出是fw和bw相互concat的结果(一般的话,如上所述只有fw)

因此假设一般的LSTM比如tf.nn.dynamic_rnn()的outputs的shape是[batch_size,steps,dim,m],
那么tf.nn.static_bidirectional_rnn()的outputs的shape就是[steps,batch_size,dim*2,m](注意,是list类型),
知道Bidirectional Recurrent Neural Networks的原理就能理解,最后的输出是前向和后向输出的concat。

看tensorflow源代码也知道:
  flat_outputs = tuple(
      array_ops.concat([fw, bw], 1)
      for fw, bw in zip(flat_output_fw, flat_output_bw))
因此,outputs[:][:,0:dim]对应的是fw
outputs[:][:,dim:dim*2]对应的是bw
(注意,这里outputs是list类型,所以用[][]来分开索引)

由于fw和bw反向,因此state_fw的h对应的是outputs[-1][:,0:dim,:](最后一帧)
state_bw的h对应的是outputs[0][:,dim:dim*2,:](第一帧)



附上我的一些实验结果(我用的是convolution LSTM,而且是变长的序列),所以数据的shape会有些不一样

output:

tensorflow中对lstm及双向lstm的理解_第1张图片

state_fw(state_fw[1]就是前向的h,对应outputs[10][:,0:64]):

tensorflow中对lstm及双向lstm的理解_第2张图片

state_bw(state_bw[1]就是后向的h,对应outputs[0][:,64:128]):

tensorflow中对lstm及双向lstm的理解_第3张图片

以上3个是网络的输出。

下面进一步观察outputs里面的数据,

output[0]在axis=1维度上的切分:

tensorflow中对lstm及双向lstm的理解_第4张图片

output[10]在axis=1维度上的切分:

tensorflow中对lstm及双向lstm的理解_第5张图片


可见,结果是正确的。




你可能感兴趣的:(Python,Tensorflow,RNN)