RNNCell是TensorFlow中发RNN基本单元。
本身是一个抽象类,拥有两个子类,一个是BasicRNNCell,另一个是BasicLSTMCell。
(注:RNNCell:是抽象类不能进行实例化,可以使用它的子类 BasicRNNCell 或BasicLSTMCell 进行实例化,得到 cell )
RNNCell的三个要点
call方法,所有RNNCell的子类都会实现一个call函数。利用call函数可以实现RNN的单步计算。
对于一个已经实例化好的基本单元cell 调用形式为:
(output, next_ state) = cell.call(input, state)
RNNCell的类属性state_size和output_size分别规定了隐层的大小和输出向量的大小。
通常是以batch形式输入数据,input的形状为(batch_size,input_size),
调用call函数时对应的隐层的形状是(batch_size,state_size),
输出的形状是(batch_size,output_size)。
在TensorFlow中定义一个基本的RNN单元的方法为:
import tensorflow as tf
rnn_cell= tf.nn.rnn_cell.BasicRNNCell(num units=128)
print(rnn_cell.state_size ) #打出 state_size 着一下,应有 state_size = 128
在TensorFlow中定义一个LSTM基本单元的方法为:
import tensorflow as t f
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num units=128)
print(lstm_cell.state_size) #state_size = LSTMStateTuple(c=l28, h=128)
LSTM可以看做有h和C两个隐层。在TensorFlow中LSTM基本单元的state_size由两部分组成,一部分是c,另一部分是h。
具体使用时,可以通过state.h以及state.c进行访问,下面是一个示例代码:
import tensorflow as i::f
import numpy as np
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size
#通过zero_state方法得到一个全0的初始化状态
hO = lstm_cell.zero_state(32, np.float32)
#调用call方法实现单步计算
output, hl = lstm_cell.call(inputs, hO)
#查看h1的状态
print(hl.h) # shape=(32, 128)
print(hl.c) # shape=(32, 128)
堆叠RNN : MultiRNNCell
单层RNN能力有限,需要多层RNN。
将x输入到第一层RNN后得到隐层状态h,这个隐层状态相当于第二层RNN的输入,第二层RNN的隐层状态又相当于第三层RNN的输入,以此类推。三层RNN串联
在TensorFlow中,使用tf.nn.rnn_cell.MultiRNNCell函数对RNN进行堆叠,代码如下:
import tensorflow as tf
import numpy as np
#每次调用这个函数返回一个BasicRNNCell
def get_a_cell():
return tf.nn.rnn_cell.BasicRNNCell(num_units=128)
#用tf.nn.rnn _cell_MultiRNNCell创建三层RNN
cell= tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])
#得到的cell实际也是RNNCell的子类
#它的state_size是(128,128,128)代表3个隐层状态,每个隐层状态是128
print(cell.state_size) # (128, 128, 128)
#使用对应的call函数
inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch size
#通过zero_state方法得到一个全0的初始状态
hO = cell.zero_state(32, np.float32)
output, hl = cell.call(inputs, hO)
print(hl) # tuple 中合有 3 个 32xl28 的向噩