RNN01

import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell

# Note: this snippet uses the pre-1.0 TensorFlow API (rnn.rnn, rnn_cell, and
# tf.split(axis, num_splits, value)). n_input, n_steps and n_hidden are
# expected to be defined at module level.
def RNN(x, weights, biases):
    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: list of 'n_steps' tensors of shape (batch_size, n_input)
    # Swap the batch_size and n_steps axes -> (n_steps, batch_size, n_input)
    x = tf.transpose(x, [1, 0, 2])
    # Reshape to (n_steps*batch_size, n_input)
    x = tf.reshape(x, [-1, n_input])
    # Split along axis 0 to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.split(0, n_steps, x)
    # Define an LSTM cell with TensorFlow; n_hidden is the LSTM size
    # (the number of hidden units in the cell)
    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Run the cell over the sequence; `outputs` holds one tensor per time step
    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
    # Linear activation applied to the output of the last time step
    return tf.matmul(outputs[-1], weights['out']) + biases['out']
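The transpose, reshape and split steps in the function above can be hard to picture, so here is a minimal sketch of the same shape change using plain Python lists instead of tensors. The helper name `to_step_list` and the toy batch are made up for illustration; this is not TensorFlow code, only a way to see which rows end up in which time step.

```python
def to_step_list(x, n_steps):
    """x is a nested list of shape (batch_size, n_steps, n_input)."""
    batch_size = len(x)
    # Transpose: (batch_size, n_steps, n_input) -> (n_steps, batch_size, n_input)
    transposed = [[x[b][t] for b in range(batch_size)] for t in range(n_steps)]
    # Reshape: flatten the first two axes -> (n_steps * batch_size, n_input)
    flat = [row for step in transposed for row in step]
    # Split along axis 0 into n_steps chunks, each of shape (batch_size, n_input)
    return [flat[t * batch_size:(t + 1) * batch_size] for t in range(n_steps)]


# Toy batch (hypothetical data): 2 sequences, 3 time steps, 2 features per step.
batch = [
    [[1, 2], [3, 4], [5, 6]],
    [[7, 8], [9, 10], [11, 12]],
]
steps = to_step_list(batch, n_steps=3)
print(len(steps))  # 3
print(steps[0])    # [[1, 2], [7, 8]]
```

Each element of `steps` groups the inputs of every sequence in the batch at one time step, which is the list-of-tensors layout that the `rnn` call expects.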
[Figure 5]
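# Smallest building block: one LSTM cell with `size` hidden units
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.0, state_is_tuple=True)
# During training, apply dropout to the cell's outputs
if is_training and config.keep_prob < 1:
    lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
        lstm_cell, output_keep_prob=config.keep_prob)
# Stack config.num_layers cells on top of each other
cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers, state_is_tuple=True)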

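Each small rectangle in the figure represents one cell. Each cell in turn has a somewhat more complex internal structure, and one cell has n hidden units. When building the cell structure (BasicLSTMCell), you first define the smallest cell unit, that is, the small rectangle; the one parameter you have to supply is hidden_units_size.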
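To make "one cell has n hidden units" concrete, here is a rough sketch of a single LSTM cell step in plain Python. The class `TinyLSTMCell` and its random weights are hypothetical and for illustration only; it follows the standard LSTM equations and is not TensorFlow's BasicLSTMCell implementation. The point is that the hidden-unit count fixes the size of the cell's output and state.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


class TinyLSTMCell:
    """Illustration-only LSTM cell (hypothetical, not TensorFlow's class)."""

    def __init__(self, n_input, n_hidden, forget_bias=1.0, seed=0):
        rng = random.Random(seed)
        self.n_hidden = n_hidden
        self.forget_bias = forget_bias
        # One weight matrix per gate (input, candidate, forget, output), each
        # mapping the concatenation [x, h] to n_hidden values.
        self.w = [[[rng.uniform(-0.1, 0.1) for _ in range(n_input + n_hidden)]
                   for _ in range(n_hidden)] for _ in range(4)]
        self.b = [[0.0] * n_hidden for _ in range(4)]

    def _gate(self, k, xh):
        # Pre-activation of gate k: W_k @ [x, h] + b_k
        return [sum(w * v for w, v in zip(row, xh)) + b
                for row, b in zip(self.w[k], self.b[k])]

    def step(self, x, state):
        c, h = state
        xh = list(x) + list(h)
        i, g, f, o = (self._gate(k, xh) for k in range(4))
        # New cell state: keep part of the old state, add gated candidate values
        new_c = [cv * sigmoid(fv + self.forget_bias) + sigmoid(iv) * math.tanh(gv)
                 for cv, fv, iv, gv in zip(c, f, i, g)]
        # New hidden state (also the cell's output at this step)
        new_h = [math.tanh(cv) * sigmoid(ov) for cv, ov in zip(new_c, o)]
        return new_h, (new_c, new_h)

    def num_parameters(self):
        weights = sum(len(row) for gate in self.w for row in gate)
        biases = sum(len(b) for b in self.b)
        return weights + biases


cell = TinyLSTMCell(n_input=3, n_hidden=5)
state = ([0.0] * 5, [0.0] * 5)           # (cell state, hidden state)
out, state = cell.step([1.0, 0.5, -0.5], state)
print(len(out))                # 5, one value per hidden unit
print(cell.num_parameters())   # 180 = 4 * (3 + 5) * 5 + 4 * 5
```

In this sketch the only sizes to choose are the input width and the number of hidden units; everything else (gate weights, state vectors, output length) follows from them.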
