TF: Multi-Layer LSTMs and Fusing the States of Two Encoders

First, build a multi-layer LSTM network; second, implement the concat of two LSTM states and analyze the structure of state.

The first problem hides a pitfall I had never paid attention to. Compare the three examples below:

import tensorflow as tf

num_units = [20, 20]

# Unit1: one fresh cell per layer via a list comprehension -- OK
# X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
# multi_rnn = [tf.nn.rnn_cell.BasicLSTMCell(num_units=units) for units in num_units]
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)

# Unit2: one fresh cell per layer via an explicit loop -- OK
# X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
# multi_rnn = []
# for i in range(2):
#     multi_rnn.append(tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units[i]))
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)

# Unit3 *********ERROR***********
X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
X = tf.reshape(X, [-1, 5, 6])
# single_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=20)  # [single_cell] * 2 fails the same way
# [cell] * 2 copies the *reference*: both layers hold the same cell object,
# so they try to share one set of weights.
lstm_cells = tf.contrib.rnn.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(num_units=20)] * 2)
# Raises a ValueError (the exact message depends on the TF 1.x version):
# layer 0 builds the shared weights for 6-dim inputs, layer 1 then needs
# them for 20-dim inputs, and the shapes cannot match.
output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)

print(output)
print(state)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for var in tf.global_variables():
        print(var.op.name)
    output_run, state_run = sess.run([output, state])
   
I had honestly never noticed this problem before. Even when all layers have the same number of units, each layer must get its own cell object, which is why the code is always written in the Unit2 form; [cell] * 2 as in Unit3 puts one and the same instance into both layers.
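For reference, a minimal sketch of what the working configuration yields (assuming TF 1.x; with time_major=True the batch size here is 5):

import tensorflow as tf

num_units = [20, 20]
X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)  # [max_time=3, batch=5, depth=6]

# One fresh cell object per layer -- the Unit1/Unit2 pattern.
cells = [tf.nn.rnn_cell.BasicLSTMCell(num_units=u) for u in num_units]
stacked = tf.contrib.rnn.MultiRNNCell(cells)
output, state = tf.nn.dynamic_rnn(stacked, X, time_major=True, dtype=tf.float32)

print(output)  # shape [3, 5, 20]: the top layer's output at every time step
print(state)   # a 2-tuple, one LSTMStateTuple per layer:
               # (LSTMStateTuple(c=<[5, 20]>, h=<[5, 20]>),
               #  LSTMStateTuple(c=<[5, 20]>, h=<[5, 20]>))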

The second problem is fusing the states of two encoders while preserving the state type (LSTMStateTuple for LSTM, a plain tensor for GRU), so the fused state can still be used wherever that structure is expected, e.g. as a decoder's initial_state.

import tensorflow as tf

def concate_rnn_states(num_layers, encoder_state_local, encoder_state_global):
    '''
    Concat the per-layer states of two encoders along the feature axis,
    preserving the state type of each layer.
    :param num_layers: number of layers in each MultiRNNCell encoder
    :param encoder_state_local:
        For LSTM, a tuple of LSTMStateTuple, one per layer:
        (LSTMStateTuple(c=<[batch, units]>, h=<[batch, units]>),
         LSTMStateTuple(c=<[batch, units]>, h=<[batch, units]>))
        For GRU/RNN, a tuple of plain tensors, one per layer:
        (<[batch, units]>,
         <[batch, units]>)
    :param encoder_state_global: same structure as encoder_state_local
    :return: tuple of per-layer concatenated states
    '''
    encoder_states = []
    for i in range(num_layers):
        if isinstance(encoder_state_local[i], tf.nn.rnn_cell.LSTMStateTuple):
            # for lstm
            encoder_state_c = tf.concat(values=(encoder_state_local[i].c, encoder_state_global[i].c), axis=1,
                                        name="concat_layer{}_state_c".format(i))
            encoder_state_h = tf.concat(values=(encoder_state_local[i].h, encoder_state_global[i].h), axis=1,
                                        name="concat_layer{}_state_h".format(i))
            encoder_state = tf.contrib.rnn.LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
        elif isinstance(encoder_state_local[i], tf.Tensor):
            # for gru and vanilla rnn: the per-layer state is a single tensor
            encoder_state = tf.concat(values=(encoder_state_local[i], encoder_state_global[i]), axis=1,
                                      name="concat_layer{}_state".format(i))
        else:
            raise TypeError("Unsupported state type: {}".format(type(encoder_state_local[i])))

        encoder_states.append(encoder_state)
    return tuple(encoder_states)

num_units = [20, 20]

# Demo: feed the same input through two 2-layer GRU encoders
X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)  # [batch=3, time=5, depth=6] (time_major=False below)
X = tf.reshape(X, [-1, 5, 6])

with tf.variable_scope("encoder1") as scope:
    local_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    local_lstm_cells = tf.contrib.rnn.MultiRNNCell(local_multi_rnn)
    encoder_output_local, encoder_state_local = tf.nn.dynamic_rnn(local_lstm_cells, X, time_major=False, dtype=tf.float32)

with tf.variable_scope("encoder2") as scope:
    global_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    global_lstm_cells = tf.contrib.rnn.MultiRNNCell(global_multi_rnn)
    encoder_output_global, encoder_state_global = tf.nn.dynamic_rnn(global_lstm_cells, X, time_major=False, dtype=tf.float32)

print("concat output")
encoder_outputs = tf.concat([encoder_output_local, encoder_output_global], axis=-1)
print(encoder_output_local)
print(encoder_outputs)

print("concat state")
print(encoder_state_local)
print(encoder_state_global)
encoder_states = concate_rnn_states(2, encoder_state_local, encoder_state_global)
print(encoder_states)
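
The demo above only exercises the GRU branch (plain tensors). Here is a minimal sketch of the LSTM case, reusing the same helper, so the LSTMStateTuple branch is hit (still TF 1.x; the scope names encoder1_lstm/encoder2_lstm are made up for illustration):

with tf.variable_scope("encoder1_lstm"):
    cells = tf.contrib.rnn.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(num_units=u) for u in num_units])
    _, lstm_state_local = tf.nn.dynamic_rnn(cells, X, time_major=False, dtype=tf.float32)

with tf.variable_scope("encoder2_lstm"):
    cells = tf.contrib.rnn.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(num_units=u) for u in num_units])
    _, lstm_state_global = tf.nn.dynamic_rnn(cells, X, time_major=False, dtype=tf.float32)

# Each layer is still an LSTMStateTuple, but c and h are now [batch, 40].
lstm_states = concate_rnn_states(2, lstm_state_local, lstm_state_global)
print(lstm_states)

Note that concatenating doubles the feature dimension: if the fused state is used as a decoder's initial_state, the decoder cells need num_units equal to the sum of the two encoders' units (here 40).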
