在上一篇博客中详细地介绍了tf.nn.dynamic_rnn这个函数的参数和作用,接下来就来介绍一下改参数的两个输出outputs和state的具体含义。
outputs和state的关系直接了当地说便是:
outputs是最后一层每个step的输出,states是每一层的最后那个step的输出。
一、先来看一下tf.nn.dynamic_rnn的定义:
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
上面的参数均为输入参数,这些参数的具体意义在上一个博客中均已介绍过了。
这里再次提及是因为tf.nn.dynamic_rnn函数的输出,与输入参数inputs和time_major关联紧密。先回顾一下inputs
inputs:LSTM的输入。默认格式是[batch_size,num_step,vector_size]
其中,batch_size是输入的这批数据的数量;num_step(也经常写作max_time)是这批数据序列最长的长度,也就是样本的序列长度(时间步长);vector_size(也经常写作input_size)是cell中神经元的个数,也是输入向量的维度。
二、outputs和state各自结构
outputs.:outputs是一个tensor
如果time_major==False,outputs形状为 [ batch_size, num_step, vector_size ](因为默认是False,所以通常都是这个形式)。
如果time_major==True,outputs形状为 [num_step, batch_size, vector_size](要求rnn输入与rnn输出形状保持一致)
state:state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, vector_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, vector_size ]。
为什么LSTM网络的state会多出一个2呢?
这是因为LSTM网络的特殊构造,如下图所示,LSTM网络有两个输出:Ct和ht。
其中,Ct是主干线信息,代表着经过LSTM的cell过滤后留下来的有用信息,仍要往后面传递;ht则是代表了当前这个LSTM的cell的最终输出。
Ct和ht共同构成了这个LSTM的state,所以会多出一个数字2,这个2就代表了Ct和ht这两项(PS:如果是GRU网络,state就只有一个,GRU将Ct 和 ht进行了简化,合并成了一个ht)。
对照着一开始就提到的outputs和state的关系(outputs是最后一层每个step的输出,states是每一层的最后那个step的输出),可知:
对LSTM网络,state是个tuple,代表和Ct和ht,并且ht与outputs中的对应的最后一个时刻的输出相等。
假设state形状为[ 2,batch_size, vector_size ],outputs形状为 [ batch_size, num_step, vector_size],那么state[ 1, batch_size, : ] == outputs[ batch_size, -1, : ](同理对于GRU,那么同理,state ==outputs[ -1 ])
还是通过代码示例具体演示一下
三、代码示例
import tensorflow as tf
import numpy as np
def dynamic_rnn(rnn_type='lstm'):
# 创建输入数据,3代表batch size,6代表输入序列的最大步长(max time),8代表每个序列的维度
X = np.random.randn(3, 6, 4)
# 第二个输入的实际长度为4
X[1, 4:] = 0
#记录三个输入的实际步长
X_lengths = [6, 4, 6]
rnn_hidden_size = 5
cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
o1, s1 = session.run([outputs, last_states])
print(np.shape(o1))
print(o1)
print(np.shape(s1))
print(s1)
if __name__ == '__main__':
dynamic_rnn(rnn_type='lstm')
得到如下输出(为了便于做标记,我截取的是图片形式):
如下图所示,input的形状为 [ 3, 6, 4 ],outputs的形状为 [ 3, 6, 5 ],state形状为 [ 2, 3, 5 ]。
state第一部分为c,代表cell state;第二部分为h,代表hidden state。其中ht与对应的outputs的最后一行(确切地说是最后一个不为0的部分)是相等的(用红框标识)。
参考:
https://blog.csdn.net/u010960155/article/details/81707498