TensorFlow Keras LSTM 输出解释

参考文章:
What does Tensorflow LSTM return?
Tensorflow RNN LSTM output explanation

>>> inputs = tf.random.normal([32, 10, 8])
>>> lstm = tf.keras.layers.LSTM(4)
>>> output = lstm(inputs)
>>> print(output.shape)
(32, 4)
>>> lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
>>> whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
>>> print(whole_seq_output.shape)
(32, 10, 4)
>>> print(final_memory_state.shape)
(32, 4)
>>> print(final_carry_state.shape)
(32, 4)

TensorFlow Keras LSTM 输出解释_第1张图片
其中图里上方的输出 h t h_t ht可以视为 o t o_t ot

在Keras中如果return_state=True则LSTM单元有三个输出,分别为

  • 一个输出状态(output state) o t o_t ot
  • 一个隐藏状态(hidden state) h t h_t ht
  • 一个单元状态(cell state) c t c_t ct

在keras 文档中给出的写法如下:

whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)

在文档中,他们不使用隐藏和单元状态这些术语。他们使用memory state表示短期记忆,即上面提到的隐藏状态。用carry state 通过所有LSTM单元,即上面提到的单元状态。

下面是前向传播的一部分源码

def step(cell_inputs, cell_states):
    """Step function that will be used by Keras RNN backend."""
    h_tm1 = cell_states[0]   #previous memory state
    c_tm1 = cell_states[2]   #previous carry state

    z = backend.dot(cell_inputs, kernel)
    z += backend.dot(h_tm1, recurrent_kernel)
    z = backend.bias_add(z, bias)

    z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)

    i = nn.sigmoid(z0)
    f = nn.sigmoid(z1)
    c = f * c_tm1 + i * nn.tanh(z2)
    o = nn.sigmoid(z3)

    h = o * nn.tanh(c)
    return h, [h, c]

从源码中可以看出,第一个和第二个输出是output/hidden state,第三个输出是cell state。并且从注释中可以看出,将hidden state 命名为 memory state ;将cell state 命名为 carry state。

return_sequences=True时,whole_seq_output是整个序列的输出,维度为(batch_size,seq_length,units)。
return_sequences=False时,whole_seq_output是最后一个单元的输出,维度为(batch_size,units),此时与第二个输出相同。

你可能感兴趣的:(python,python,深度学习,tensorflow,keras,lstm)