循环神经网络系列(一) Tensorflow中BasicRNNCell
循环神经网络系列(二)Tensorflow中dynamic_rnn
经过前面两篇博文,我们介绍了如何定义一个RNN单元,以及用dynamic_rnn
来对其在时间维度(横轴)上展开。我们今天要介绍的就是如何叠加多层RNN单元(如双向LSTM),同时对其按时间维度展开。具体多层RNN展开长什么样呢?还是用最直观的图来展示,如下所示:
其中A,B
分别表示两个RNN单元,然后再分别对其按时间维度time_step=3
进行展开,最终形成了两层,包含两个状态和3个输出。要完成这样一个例子,在Tensorflow中该如何来实现呢?
1. 先定义两个RNN单元
def get_a_cell(output_size):
return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(output_size) for _ in range(2)])
经过上面的8行代码,我们就定义好了两个堆叠在一起的RNN单元A和B,如下图所示:
2. 利用dynamic_rnn
进行展开
import tensorflow as tf
def get_a_cell(output_size):
return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(output_size) for _ in range(2)])
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
print(outputs)
print(final_state)
>>
Tensor("rnn/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4, 5), dtype=float32)
(<tf.Tensor 'rnn/while/Exit_2:0' shape=(4, 5) dtype=float32>, <tf.Tensor 'rnn/while/Exit_3:0' shape=(4, 5) dtype=float32>)
从第23行结果可知,输出的最后状态有两个,形状分别都是shape=(4,5)
,这也符合我们的预期;而第22行的输出结果shape=(3,4,5)
有表示什么意思呢?这里的3就不表示维度了,而表示输出结果有3部分,每个部分的大小都是shape=(4,5)
,这也是我们所预期的。并且B层的final_state应该使等于第三个输出的。
3. 喂个实例跑跑
import tensorflow as tf
import numpy as np
def get_a_cell(output_size):
return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(output_size) for _ in range(2)])
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
print(outputs)
print(final_state)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x1
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x2
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]]) # x3
sess = tf.Session()
sess.run(tf.global_variables_initializer())
a, b = sess.run([outputs, final_state], feed_dict={inputs: X})
print('outputs:')
print(a)
print('final_state:')
print(b)
>>
Tensor("rnn/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4, 5), dtype=float32)
(<tf.Tensor 'rnn/while/Exit_2:0' shape=(4, 5) dtype=float32>, <tf.Tensor 'rnn/while/Exit_3:0' shape=(4, 5) dtype=float32>)
outputs:
[[[-0.6958626 -0.6776572 0.15731043 -0.6311886 0.20267256]
[ 0.07732188 0.09182965 -0.49770945 0.0051106 0.23445603]
[-0.304461 -0.2706095 -0.4083268 -0.3364025 0.26729658]
[-0.38100582 -0.35050285 -0.2153194 -0.3686508 0.21973696]]
[[-0.38028494 -0.39984316 0.5924934 -0.7433707 0.45858386]
[ 0.15477817 0.06120307 -0.23038468 -0.2532196 0.19319542]
[-0.09605556 -0.23243633 0.18608333 -0.6444844 0.34893066]
[-0.15772797 -0.2529126 0.32016686 -0.6125384 0.33331177]]
[[-0.45718285 -0.20688602 0.66812176 -0.81284994 -0.03955056]
[ 0.16529301 0.2245452 -0.45850635 -0.36383444 0.18540041]
[-0.0918629 0.11388774 0.01027385 -0.7402484 0.06189062]
[-0.21528585 0.00840321 0.20390712 -0.71303254 0.04809263]]]
final_state:
(array([[ 0.01885682, 0.79334605, -0.99330646, -0.19715786, 0.8772415 ],
[-0.43402836, -0.2537776 , -0.52755517, 0.5360404 , -0.38291538],
[-0.49418357, 0.28655267, -0.91146743, 0.4856847 , 0.22705963],
[-0.3087254 , 0.42241457, -0.8743213 , 0.26078507, 0.3464944 ]],
dtype=float32),
array([[-0.45718285, -0.20688602, 0.66812176, -0.81284994, -0.03955056],
[ 0.16529301, 0.2245452 , -0.45850635, -0.36383444, 0.18540041],
[-0.0918629 , 0.11388774, 0.01027385, -0.7402484 , 0.06189062],
[-0.21528585, 0.00840321, 0.20390712, -0.71303254, 0.04809263]],
dtype=float32))
可以看到output有3个部分,final_state有2个部分,且output的第三个结果和final_state的第二个结果相同,符合我们上面的猜想。
注意:
如果每层的输出大小要不同的话,直接在定义多层单元的时候填上不同的参数即可!
output_size = [5, 6]
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(size) for size in output_size])