tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True): n_hidden is the number of hidden units in the cell. forget_bias is the bias added to the LSTM's forget gate: a value of 1.0 means no information is forgotten, while 0.0 means everything is forgotten. state_is_tuple defaults to True, which is also the officially recommended setting; it means the returned state is represented as a tuple (c, h). The cell also provides a state-initialization method, zero_state(batch_size, dtype), which takes two arguments: batch_size, the number of samples in an input batch, and dtype, the data type.
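As a minimal sketch (assuming TensorFlow 1.x), zero_state returns an LSTMStateTuple whose c and h entries are both zero tensors of shape [batch_size, n_hidden]:

import tensorflow as tf

cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)  # n_hidden = 10
init_state = cell.zero_state(4, dtype=tf.float32)  # batch_size = 4
print(init_state.c.shape)  # (4, 10)
print(init_state.h.shape)  # (4, 10)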
For example:
import tensorflow as tf
batch_size = 4
input = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(cell, input, initial_state=init_state, time_major=True)  # If time_major is True, the first dimension of the input is the time steps; this is recommended because it runs a bit faster.
# If time_major is False, the second dimension of the input is the time steps.
# With time_major=True, output has shape [steps, batch_size, depth]; otherwise it is [batch_size, steps, depth], the same layout as the input.
# final_state is the final state of the LSTM after processing the whole sequence; it contains c and h, each of shape [batch_size, n_hidden].
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output))
    print(sess.run(final_state))
Output:
[[[-0.17837059 0.01385643 0.11524696 -0.04611184 0.05751593 -0.02275656
0.10593235 -0.07636188 0.12855089 0.00768109]
[ 0.07553699 -0.23295973 -0.00144508 0.09547552 -0.05839045 -0.06769165
-0.41666976 0.3499622 -0.01430317 -0.02479473]
[ 0.08574327 -0.05990489 0.06817424 0.03434218 0.10152793 -0.10594042
-0.25310516 0.07232092 0.064815 0.0659876 ]
[ 0.15607212 -0.31474397 -0.06477047 -0.06982201 -0.05489461 0.0188695
-0.30281037 0.39494631 -0.05267519 -0.03253869]]
[[-0.03209484 -0.06323308 -0.25410452 -0.10886975 0.00253956 -0.08053195
0.18729064 -0.0788438 0.14781287 -0.20489833]
[ 0.3164973 -0.10971865 -0.35004857 -0.00576114 -0.08092841 0.00883496
-0.17579219 0.19092172 -0.0237403 -0.43207553]
[ 0.2409949 -0.17808972 -0.1486263 0.02179234 -0.21656732 0.0522153
-0.21345614 0.18841118 -0.0094095 -0.34072629]
[ 0.12034108 -0.23767222 0.03664704 0.13274716 -0.04165298 -0.04095407
-0.31182185 0.36334303 -0.01146755 0.05028744]]
[[-0.12453001 -0.1567502 -0.16580626 -0.03544752 0.06869993 0.09097657
-0.02214662 -0.18668351 0.06159507 -0.35843855]
[ 0.2010586 0.03222289 -0.31237942 0.01898964 -0.08158109 -0.02510365
0.02967031 0.12587228 -0.22250202 -0.08734316]
[ 0.14316584 0.02029586 -0.1062321 0.02968353 -0.02318866 0.07653226
-0.13600637 -0.00440343 0.07305693 -0.26385978]
[ 0.23669831 -0.13415271 -0.10488234 0.03128149 -0.11343875 -0.05327768
-0.22888957 0.17797095 -0.02945257 -0.18901967]]]
LSTMStateTuple(c=array([[-0.72714508, 0.32974839, 0.67756736, 0.11421457, 0.39167076,
0.31247479, 0.0755761 , -0.62171376, 0.58582318, -0.19749212],
[ 0.44815305, 0.06901363, -0.88840145, 0.22841501, 0.04539755,
0.17472507, -0.50547051, 0.46637267, -0.07522876, -0.80750966],
[-0.19392423, -0.16717091, -0.19510591, -0.48713976, -0.18430954,
0.1046299 , 0.30127296, -0.03556332, -0.37671563, -0.1388765 ],
[-0.47982571, 0.2172934 , 0.56419176, 0.15874679, 0.29927608,
0.16362543, 0.11525643, -0.47210076, 0.56833684, -0.18866351]], dtype=float32), h=array([[-0.36339632, 0.17585619, 0.29174498, 0.03471305, 0.2237694 ,
0.13323013, 0.03002708, -0.26190156, 0.28289214, -0.12495621],
[ 0.1543802 , 0.04264591, -0.27087522, 0.084597 , 0.01555507,
0.10631134, -0.23696639, 0.2758382 , -0.03724022, -0.4389703 ],
[-0.14088678, -0.10961234, -0.10831701, -0.19923639, -0.10324109,
0.04290821, 0.10720341, -0.01477169, -0.14518294, -0.04280116],
[-0.34502122, 0.10841226, 0.32169446, 0.03053316, 0.20867576,
0.04689977, 0.03286072, -0.11068864, 0.37977526, -0.12110116]], dtype=float32))
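As a follow-up sketch (again assuming TensorFlow 1.x; variable names such as inputs, out, and state are illustrative), the same model with time_major=False expects a batch-major input of shape [batch_size, steps, depth] and returns output in the same layout; fetching output and final_state in a single sess.run also shows that final_state.h is the output at the last time step:

import numpy as np
import tensorflow as tf

batch_size = 4
inputs = tf.random_normal(shape=[batch_size, 3, 6], dtype=tf.float32)  # [batch_size, steps, depth]
cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state, time_major=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fetch both tensors in one run so they come from the same random input batch
    out, state = sess.run([output, final_state])
    print(out.shape)      # (4, 3, 10), i.e. [batch_size, steps, n_hidden]
    print(state.h.shape)  # (4, 10)
    print(np.allclose(out[:, -1, :], state.h))  # True: h is the last-step output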