Tensorflow中dynamic_rnn的用法

1 API接口

dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

2 举例说明

# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)

3 重要参数说明

cell:输入一个RNNcell实例

inputs:RNN神经网络的输入,如果 time_major == False (default), 输入的形状是: [batch_size, max_time, embedding_size],如果 time_major == True, 输入的形状是: [ max_time, batch_size, embedding_size].

initial_state: RNN网络的初始状态,网络需要一个初始状态,对于普通的RNN网络,初始状态的形状是:[batch_size, cell.state_size].

4 返回值

outputs: RNN网络的输出单元.
如果time_major == False (default), 输出单元的形状是: [batch_size, max_time, cell.output_size].
如果 time_major == True, 输出单元的形状是: [max_time, batch_size, cell.output_size].

state:RNN网络最终的状态,即RNN网络的最终输出。他的形状是[batch_size, cell.output_size]

你可能感兴趣的:(NLP)