LSTM中state 与 output关系

本文通过简单的实验说明lstm中 state与output之间的关系

假设参数如下:

batch_size = 4 # 训练语料中一共有4句话
sequeue_len = 5 # 每句话只有5个词语
ebedding = 6 # 每个词语的词向量维度为 6
hidden_size = 10 # 神经元个数为10

(1)output说明

    首先,比方说我们训练语料一共有4句话,每句话有5个词语,每个词语ebedding为6个维度,所以输入数据的

shape=[4,5,6]

    然后,经过一个或者多个神经元为10的 cell,(多个cell也是串联的,所以最后结果也就只有一份)得到 output 和 state。

     output shape = [4,5,10]

    最后,output[:, -1, :] 我们取每句话中最后一个时刻(词语)的输出作为下一步的输入(相当与用最后一个时刻的输出来表示这句话),这样,就得到了 4 x 10 的矩阵。

(2)state说明
state 是个tuple(c, h)
state = LSTMStateTuple(c=array([4,10], dtype=float32),  h=array([4,10], dtype=float32))
说明:每句话经过当前cell后会得到一个state,状态的维度就是隐藏神经元的个数,此时与每句话中包含的词语个数无关,这样,state就只跟 训练数据中包含多少句话(batch_size) 和 隐藏神经元个数(hidden size)有关了。
     其中 c =[batch_size, hidden_size], h = [batch_size, hidden_size]

说明:经过多少个cell,就有多少个LSTMStateTuple,即每个cell都会输出一个 tuple(c, h)

(3)state 中的 h 跟output 的最后一个时刻的输出是一样的,即:
output[:, -1, :] = state[0].h

测试代码如下:

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
def get_a_cell():
    return tf.nn.rnn_cell.BasicLSTMCell(num_units=10) #也可以换成别的,比如GRUCell,BasicRNNCell等等

# X = tf.random_normal(shape=[4, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
X = np.array([[[1,2,3,4,6],
         [0,1,2,3,8],
         [3,6,8,1,2],
         [2,3,6,4,1]],
        [[2,3,5,6,8],
         [3,4,5,1,7],
         [6,5,9,0,2],
         [2,3,4,6,1]],
        [[2,3,5,1,6],
         [3,5,2,4,7],
         [4,5,2,4,1],
         [3,4,3,2,6]]])
X = tf.to_float(X)

stacked_lstm = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(1)], state_is_tuple=True)
# initial_state = stacked_lstm.zero_state(5, tf.float32)
# output, state = tf.nn.dynamic_rnn(stacked_lstm, X, initial_state=initial_state, time_major=True)
output, state = tf.nn.dynamic_rnn(stacked_lstm, X, time_major=False, dtype=tf.float32)

last = output[:, -1, :]  # 取最后一个时序输出作为结果
# fc_dense = tf.layers.dense(last, 10, name='fc1')
# fc_drop = tf.contrib.layers.dropout(fc_dense, 0.8)
# fc1 = tf.nn.relu(fc_drop)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(last)

    print '-------------------------\n'

    print sess.run(state[0].h)

运行结果:

[[ 0.0313809  -0.01724125  0.19791707  0.00551758  0.11339158 -0.09302805
   0.23602076  0.68961358  0.01553136  0.36130285]
 [-0.00354887  0.0148001   0.16432436  0.01897741  0.22235572 -0.08562981
   0.14585365  0.75319165 -0.00278428  0.28128645]
 [-0.00527866 -0.00108556  0.19138697  0.01754982  0.15600391 -0.04693945
   0.02330696  0.91875976  0.12326637  0.33844444]]
-----------------------------------------------------------------------------------------

[[ 0.0313809  -0.01724125  0.19791707  0.00551758  0.11339158 -0.09302805
   0.23602076  0.68961358  0.01553136  0.36130285]
 [-0.00354887  0.0148001   0.16432436  0.01897741  0.22235572 -0.08562981
   0.14585365  0.75319165 -0.00278428  0.28128645]
 [-0.00527866 -0.00108556  0.19138697  0.01754982  0.15600391 -0.04693945

   0.02330696  0.91875976  0.12326637  0.33844444]]


版权声明:知识需要传播,如有需要,请任意转载https://blog.csdn.net/xiaokang06/article/details/80235950

你可能感兴趣的:(深度学习)