本文通过简单的实验说明lstm中 state与output之间的关系
假设参数如下:
batch_size = 4 # 训练语料中一共有4句话
sequeue_len = 5 # 每句话只有5个词语
ebedding = 6 # 每个词语的词向量维度为 6
hidden_size = 10 # 神经元个数为10
(1)output说明
首先,比方说我们训练语料一共有4句话,每句话有5个词语,每个词语ebedding为6个维度,所以输入数据的
shape=[4,5,6]
然后,经过一个或者多个神经元为10的 cell,(多个cell也是串联的,所以最后结果也就只有一份)得到 output 和 state。
output shape = [4,5,10]
最后,output[:, -1, :] 我们取每句话中最后一个时刻(词语)的输出作为下一步的输入(相当与用最后一个时刻的输出来表示这句话),这样,就得到了 4 x 10 的矩阵。
(2)state说明
state 是个tuple(c, h)
state = LSTMStateTuple(c=array([4,10], dtype=float32), h=array([4,10], dtype=float32))
说明:每句话经过当前cell后会得到一个state,状态的维度就是隐藏神经元的个数,此时与每句话中包含的词语个数无关,这样,state就只跟 训练数据中包含多少句话(batch_size) 和 隐藏神经元个数(hidden size)有关了。
其中 c =[batch_size, hidden_size], h = [batch_size, hidden_size]
说明:经过多少个cell,就有多少个LSTMStateTuple,即每个cell都会输出一个 tuple(c, h)
(3)state 中的 h 跟output 的最后一个时刻的输出是一样的,即:
output[:, -1, :] = state[0].h
测试代码如下:
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
def get_a_cell():
return tf.nn.rnn_cell.BasicLSTMCell(num_units=10) #也可以换成别的,比如GRUCell,BasicRNNCell等等
# X = tf.random_normal(shape=[4, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
X = np.array([[[1,2,3,4,6],
[0,1,2,3,8],
[3,6,8,1,2],
[2,3,6,4,1]],
[[2,3,5,6,8],
[3,4,5,1,7],
[6,5,9,0,2],
[2,3,4,6,1]],
[[2,3,5,1,6],
[3,5,2,4,7],
[4,5,2,4,1],
[3,4,3,2,6]]])
X = tf.to_float(X)
stacked_lstm = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(1)], state_is_tuple=True)
# initial_state = stacked_lstm.zero_state(5, tf.float32)
# output, state = tf.nn.dynamic_rnn(stacked_lstm, X, initial_state=initial_state, time_major=True)
output, state = tf.nn.dynamic_rnn(stacked_lstm, X, time_major=False, dtype=tf.float32)
last = output[:, -1, :] # 取最后一个时序输出作为结果
# fc_dense = tf.layers.dense(last, 10, name='fc1')
# fc_drop = tf.contrib.layers.dropout(fc_dense, 0.8)
# fc1 = tf.nn.relu(fc_drop)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(last)
print '-------------------------\n'
print sess.run(state[0].h)
运行结果:
[[ 0.0313809 -0.01724125 0.19791707 0.00551758 0.11339158 -0.09302805
0.23602076 0.68961358 0.01553136 0.36130285]
[-0.00354887 0.0148001 0.16432436 0.01897741 0.22235572 -0.08562981
0.14585365 0.75319165 -0.00278428 0.28128645]
[-0.00527866 -0.00108556 0.19138697 0.01754982 0.15600391 -0.04693945
0.02330696 0.91875976 0.12326637 0.33844444]]
-----------------------------------------------------------------------------------------
[[ 0.0313809 -0.01724125 0.19791707 0.00551758 0.11339158 -0.09302805
0.23602076 0.68961358 0.01553136 0.36130285]
[-0.00354887 0.0148001 0.16432436 0.01897741 0.22235572 -0.08562981
0.14585365 0.75319165 -0.00278428 0.28128645]
[-0.00527866 -0.00108556 0.19138697 0.01754982 0.15600391 -0.04693945
0.02330696 0.91875976 0.12326637 0.33844444]]
版权声明:知识需要传播,如有需要,请任意转载https://blog.csdn.net/xiaokang06/article/details/80235950