最近在尝试实现一个简单的LSTMCell,源码中看似只是简单地调用一下:
tf.contrib.rnn.BasicLSTMCell()
实际上包含了很多没有弄明白地方。我想把这个学习过程完整地记录一遍。
首先,构建LSTM单元需要导入:
import tensorflow as tf
import numpy as np
上周的报告已经提到,LSTM单元中喂进的数据是一个3维数据,维度分别是input_size,batch_size,time_size。这里把X作为input的数据:
X=np.random.rand(3,6,4)
#batch_size=3,time_size=6,input_size=4
再次指明,input_size和Cell中的hidden_size有关,time_size则是处理一组数据的步长,batch_size则是用户自己选定的(通常开源文献中选为128、256等,从Memory中取出,再投喂给网络)。
为了便于观察Cell的输出结果,我们把X的第二个batch做以下处理:
X=[1,4:]=0
那么输入的三组batch中,每组的实际步长为:
X_length=[6,4,6]
X的输出可以参考如下:
[[[0.97220811 0.25908799 0.54227514 0.41574578]
[0.01295309 0.3510622 0.41254816 0.40131783]
[0.64841554 0.91885768 0.67117895 0.98121062]
[0.14896025 0.3912898 0.1417619 0.43468296]
[0.22438062 0.85157355 0.10037672 0.66274456]
[0.07133907 0.86983479 0.19161431 0.15118635]]
[[0.68270615 0.7659821 0.04970863 0.43649479]
[0.96759885 0.72994591 0.77564044 0.24077003]
[0.94344651 0.98036233 0.85772773 0.67501075]
[0.21152659 0.94275251 0.09053659 0.6004612 ]
[0. 0. 0. 0. ]
[0. 0. 0. 0. ]]
[[0.21548508 0.50614013 0.6444404 0.09635282]
[0.54535114 0.15882572 0.58684033 0.2026541 ]
[0.41272127 0.62597087 0.97968376 0.08931693]
[0.86418767 0.27609746 0.69480801 0.31376662]
[0.0309335 0.36077981 0.22935523 0.12807059]
[0.7778892 0.17223188 0.7626537 0.72124185]]]
#这样就对input的情况进行了一个很清晰的展示!
#time_size是每个batch的步长,但允许有空的位置出现,
这就是第二个batch中的5、6行被全部赋为0的原因。
先来看LSTM的内部构造:
这张图说明,LSTM内部有“四大块”,分别是使用sigmoid函数和tanh函数激活的部分,如下图。
每一块都是一个全连接的神经网络,那么hidden_size就是这个神经网络的每一层的节点数目(文献中指出,LSTM内部的神经网络可以有很多层,但每层的节点数目一般而言是一样的),换言之,内部输出的h_t、c_t的长度都是hidden_size,经过反馈到达输入端,和X_t(input_size)连接以后,权重参数的列数便是(hidden_size+input_size)。
理清了这几个超参数,我们可以写出这样的代码:
hidden_size=5
#创建LSTMcell
cell=tf.contrib.rnn.BasicLSTMCell(num_units=hidden_size,state_is_tuple=True)
接下来让LSTM跑起来:
outputs,last_states=tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_length,
input=X
)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
o1,s1 = session.run([outputs,lats_states])
上面代码中出现了两个个十分重要的函数:tf.contrib.rnn.BasicLSTMCell()、tf.nn.dynamic_rnn()。前一个函数自然是根据我们提供的hidden_size来创建LSTM单元,后一个函数则是给这个LSTM单元input等参数。tf.nn.dynamic_run()有两个返回值,分别是outputs和last_states,这和LSTM的的结构是有关系的。
那么输出呢?知道了输入的各种超参数和过程,输出是什么形式?
我们已经提到,tf.nn.danamic_rnn()有两个返回值,outputs和last_states,理解LSTM的输出,需要从这两个tensor入手。
我们来看程序输入的outputs(o1):
#outputs维度是(3, 6, 5)
[[[ 0.10988916 -0.05924489 -0.00219612 0.03131131 0.11956187]
[ 0.14008177 -0.08764294 -0.00184445 0.06144539 0.18363025]
[ 0.09684535 -0.10180583 0.00621872 0.08553708 0.1787711 ]
[ 0.11151131 -0.13483347 -0.01405942 0.12652101 0.28426013]
[ 0.16092931 -0.13378875 0.01335965 0.16779374 0.31270015]
[ 0.1814057 -0.14386612 -0.01726051 0.10565656 0.32259905]]
[[ 0.06112274 -0.03399703 -0.01142925 0.01125745 0.07088148]
[ 0.09888767 -0.10105175 -0.02438729 0.10392439 0.20857589]
[ 0.13660149 -0.12307292 -0.02153195 0.12783929 0.23667747]
[ 0.13729271 -0.140286 -0.0181011 0.17082856 0.27899951]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]]
[[ 0.04344821 -0.06923458 -0.02282487 0.07093523 0.08029628]
[ 0.11463346 -0.09463423 -0.01339233 0.08762536 0.15044823]
[ 0.13361562 -0.12923288 -0.00863851 0.15211888 0.26287943]
[ 0.18059539 -0.13309941 0.0330919 0.17870714 0.31354558]
[ 0.1644137 -0.16305717 -0.01865754 0.16931057 0.34752158]
[ 0.21349885 -0.14284222 0.02098966 0.16654242 0.35304967]]]
每个batch有6个单词(time_step),每个单词有4个字母(input_size),输入3个这样的batch到hidden_size=5(则cell.output_size=hidden_size=5)的全连接网络中,输出自然就是[batch_size,time_size,hidden_size],这样以拆解,就不难理解了outputs了!
那么lats_states呢?顾名思义,last_states就是最后一个状态!
来看输出的last_states(s1):
#lats_stats的维度(2,3,5)。
#这个输出是个tuple,整体上是一个(2,3,5)的数组,实际上是两个(3,5)的数组。
LSTMStateTuple
(c=array([[ 0.37597382, -0.43745685, -0.03624983, 0.28624145, 0.64269234],
[ 0.34083248, -0.38826329, -0.03365382, 0.32866948, 0.49783422],
[ 0.48956227, -0.52104725, 0.03657497, 0.37346697, 0.69955475]]),
h=array([[ 0.1814057 , -0.14386612, -0.01726051, 0.10565656, 0.32259905],
[ 0.13729271, -0.140286 , -0.0181011 , 0.17082856, 0.27899951],
[ 0.21349885, -0.14284222, 0.02098966, 0.16654242, 0.35304967]]))
什么意思呢?
last_states实际上输出的是上面这个单元中的H_t、C_t两个tensor!(LSTM的state是由C_t和 H_t组成的。)
可以看到,H_t是和outputs的最后一行相等的(H_t决定遗忘什么、记住什么)。C_t则是整个Cell的实际输出(DRQN中连接到后面的神经元节点)。
因此shape(last_states)=[2,batch_size,hidden_size]。
可以这样理解:outputs指明了每个batch上的每个time_step的每个input的输出,而last_states(=C_t + H_t)则表明了经过LSTM抽象处理后的最终结果。
最后附上这部分的完整代码:
import tensorflow as tf
import numpy as np
def dynamic_rnn(rnn_type='lstm'):
X=np.random.rand(3,6,4)
X[1,4:]=0
X_length=[6,4,6]
rnn_hidden_size=5
if(rnn_type=='lstm'):
cell=tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size,state_is_tuple=True)
else:
cell=tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
num=cell.output_size
outputs,last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_length,
inputs=X
)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
o1,s1 = session.run([outputs,last_states])
print(X)
print(np.shape(o1))
print(o1)
print(np.shape(s1))
print(s1)
print(num)
if __name__ == '__main__':
dynamic_rnn(rnn_type='lstm')