The first problem is building a multi-layer LSTM network;
the second is concatenating the states of two LSTM encoders and analyzing the structure of the state.
As for the first problem, I had never paid attention to it before. Look at the examples below:
import tensorflow as tf
num_units = [20, 20]
#Unit1, OK
# X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
# multi_rnn = [tf.nn.rnn_cell.BasicLSTMCell(num_units=units) for units in num_units]
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)
#Unit2, OK
# X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
# multi_rnn = []
# for i in range(2):
#     multi_rnn.append(tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units[i]))
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)
# Unit3 *********ERROR***********: the list holds the SAME cell object twice
X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
X = tf.reshape(X, [-1, 5, 6])
# single_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=20)  # using [single_cell] * 2 fails the same way
lstm_cells = tf.contrib.rnn.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(num_units=20)] * 2)
output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)
print(output)
print(state)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for var in tf.global_variables():
        print(var.op.name)
    output_run, state_run = sess.run([output, state])
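For the working Unit1/Unit2 variants, print(state) shows that the final state of a stacked LSTM is a per-layer tuple of LSTMStateTuples. A sketch of that structure (my annotation; shapes inferred from the [3, 5, 6] time-major input, so batch size is 5):
# state is a 2-tuple, one LSTMStateTuple per layer; with time_major=True,
# X has shape [max_time=3, batch=5, depth=6], so every c and h is [5, 20]:
#
# (LSTMStateTuple(c=<Tensor shape=(5, 20)>, h=<Tensor shape=(5, 20)>),   # layer 0
#  LSTMStateTuple(c=<Tensor shape=(5, 20)>, h=<Tensor shape=(5, 20)>))   # layer 1
#
# output carries only the top layer's h at every step: shape (3, 5, 20).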
I really had not noticed this issue before. Although the layers usually share the same dimensionality, I happened to always write the stack in the Unit2 style. Unit3 fails because [cell] * 2 puts two references to the same BasicLSTMCell object into the list; MultiRNNCell then tries to reuse one set of weights for both layers, but layer 0 builds its kernel for an input of size 6 + 20 = 26 while layer 1 feeds it an input of size 20 + 20 = 40, so TensorFlow raises a variable-sharing/shape-mismatch ValueError. Each layer must be a distinct cell object.
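A minimal sketch of the fix (same shapes as above): construct a fresh cell per layer even when all layers have the same size, e.g. with a list comprehension whose body runs once per layer:
lstm_cells = tf.contrib.rnn.MultiRNNCell(
    [tf.nn.rnn_cell.BasicLSTMCell(num_units=20) for _ in range(2)])  # two distinct cell objects
output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)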
The second problem: fusing the states of two encoders while preserving the state type (LSTM/GRU).
import tensorflow as tf
def concat_rnn_states(num_layers, encoder_state_local, encoder_state_global):
    '''
    Concatenate the per-layer final states of two encoders along the feature
    axis, preserving the state type.
    :param num_layers: number of stacked RNN layers
    :param encoder_state_local:
        For LSTM: (LSTMStateTuple(c=[batch, units], h=[batch, units]),
                   LSTMStateTuple(c=[batch, units], h=[batch, units]))
        For GRU:  ([batch, units] Tensor,
                   [batch, units] Tensor)
    :param encoder_state_global: same structure as encoder_state_local
    :return: tuple of concatenated per-layer states
    '''
    encoder_states = []
    for i in range(num_layers):
        if isinstance(encoder_state_local[i], tf.nn.rnn_cell.LSTMStateTuple):
            # For LSTM: concat c and h separately, then rewrap in an LSTMStateTuple
            encoder_state_c = tf.concat(values=(encoder_state_local[i].c, encoder_state_global[i].c), axis=1,
                                        name="concat_layer{}_state_c".format(i))
            encoder_state_h = tf.concat(values=(encoder_state_local[i].h, encoder_state_global[i].h), axis=1,
                                        name="concat_layer{}_state_h".format(i))
            encoder_state = tf.contrib.rnn.LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
        elif isinstance(encoder_state_local[i], tf.Tensor):
            # For GRU and vanilla RNN: the state is a plain Tensor, concat directly
            encoder_state = tf.concat(values=(encoder_state_local[i], encoder_state_global[i]), axis=1,
                                      name='GruOrRnn_concat')
        encoder_states.append(encoder_state)
    return tuple(encoder_states)
num_units = [20, 20]
# Two stacked-GRU encoders over the same input
X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
X = tf.reshape(X, [-1, 5, 6])
with tf.variable_scope("encoder1") as scope:
    local_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    local_lstm_cells = tf.contrib.rnn.MultiRNNCell(local_multi_rnn)
    encoder_output_local, encoder_state_local = tf.nn.dynamic_rnn(local_lstm_cells, X, time_major=False, dtype=tf.float32)
with tf.variable_scope("encoder2") as scope:
    global_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    global_lstm_cells = tf.contrib.rnn.MultiRNNCell(global_multi_rnn)
    encoder_output_global, encoder_state_global = tf.nn.dynamic_rnn(global_lstm_cells, X, time_major=False, dtype=tf.float32)
print("concat output")
encoder_outputs = tf.concat([encoder_output_local, encoder_output_global], axis=-1)
print(encoder_output_local)
print(encoder_outputs)
print("concat state")
print(encoder_state_local)
print(encoder_state_global)
encoder_states = concat_rnn_states(2, encoder_state_local, encoder_state_global)
print(encoder_states)
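The GRU example only exercises the tf.Tensor branch of concat_rnn_states. As a sketch (my addition, reusing X and num_units from above), swapping in BasicLSTMCell exercises the LSTMStateTuple branch and shows that the result is still a valid LSTM state, just twice as wide:
# Same two-encoder setup, but with LSTM cells to hit the LSTMStateTuple branch.
with tf.variable_scope("encoder1_lstm"):
    cells = tf.contrib.rnn.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(num_units=units) for units in num_units])
    _, lstm_state_local = tf.nn.dynamic_rnn(cells, X, time_major=False, dtype=tf.float32)
with tf.variable_scope("encoder2_lstm"):
    cells = tf.contrib.rnn.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(num_units=units) for units in num_units])
    _, lstm_state_global = tf.nn.dynamic_rnn(cells, X, time_major=False, dtype=tf.float32)

lstm_states = concat_rnn_states(2, lstm_state_local, lstm_state_global)
print(lstm_states)
# Expected: a 2-tuple of LSTMStateTuple(c=[3, 40], h=[3, 40]); a decoder
# MultiRNNCell built with num_units=40 per layer can take this as initial_state.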