LSTM GRU tensorflow代码 和 原理图中的箭头 的对应关系

LSTM GRU tensorflow代码 和 原理图中的箭头 的对应关系_第1张图片
上图为LSTMcell,
向上指的箭头h_t是output,
向右指的箭头h_tC_t是state,

LSTM GRU tensorflow代码 和 原理图中的箭头 的对应关系_第2张图片
对于上图的GRU
output和state是同一个信息

import tensorflow as tf
from tensorflow.contrib import rnn

x = tf.constant([[1]], dtype = tf.float32)

lstm_cell = rnn.BasicLSTMCell(2)
gru_cell = rnn.GRUCell(2)

state0_lstm = lstm_cell.zero_state(1,dtype=tf.float32)
output,state = lstm_cell(x,state0_lstm)

state0_gru = gru_cell.zero_state(1,dtype=tf.float32)
output2,state2 = gru_cell(x,state0_gru)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(output))
    print(sess.run(state))

    print(sess.run(output2))
    print(sess.run(state2))

结果:
[[ 0.05328973 -0.07351915]]
LSTMStateTuple(c=array([[ 0.11276773, -0.12946127]], dtype=float32), h=array([[ 0.05328973, -0.07351915]], dtype=float32))

[[ 0.30531788 -0.00426328]]
[[ 0.30531788 -0.00426328]]

你可能感兴趣的:(TensorFlow)