1 API接口
dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
2 举例说明
# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
initial_state=initial_state,
dtype=tf.float32)
3 重要参数说明
cell:输入一个RNNcell实例
inputs:RNN神经网络的输入,如果 time_major == False (default), 输入的形状是: [batch_size, max_time, embedding_size],如果 time_major == True, 输入的形状是: [ max_time, batch_size, embedding_size].
initial_state: RNN网络的初始状态,网络需要一个初始状态,对于普通的RNN网络,初始状态的形状是:[batch_size, cell.state_size].
4 返回值
outputs: RNN网络的输出单元.
如果time_major == False (default), 输出单元的形状是: [batch_size, max_time, cell.output_size].
如果 time_major == True, 输出单元的形状是: [max_time, batch_size, cell.output_size].
state:RNN网络最终的状态,即RNN网络的最终输出。他的形状是[batch_size, cell.output_size]