参考文章:
What does Tensorflow LSTM return?
Tensorflow RNN LSTM output explanation
>>> inputs = tf.random.normal([32, 10, 8])
>>> lstm = tf.keras.layers.LSTM(4)
>>> output = lstm(inputs)
>>> print(output.shape)
(32, 4)
>>> lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
>>> whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
>>> print(whole_seq_output.shape)
(32, 10, 4)
>>> print(final_memory_state.shape)
(32, 4)
>>> print(final_carry_state.shape)
(32, 4)
其中图里上方的输出 h t h_t ht可以视为 o t o_t ot
在Keras中如果return_state=True
则LSTM单元有三个输出,分别为
在keras 文档中给出的写法如下:
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
在文档中,他们不使用隐藏和单元状态这些术语。他们使用memory state表示短期记忆,即上面提到的隐藏状态。用carry state 通过所有LSTM单元,即上面提到的单元状态。
下面是前向传播的一部分源码
def step(cell_inputs, cell_states):
"""Step function that will be used by Keras RNN backend."""
h_tm1 = cell_states[0] #previous memory state
c_tm1 = cell_states[2] #previous carry state
z = backend.dot(cell_inputs, kernel)
z += backend.dot(h_tm1, recurrent_kernel)
z = backend.bias_add(z, bias)
z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)
i = nn.sigmoid(z0)
f = nn.sigmoid(z1)
c = f * c_tm1 + i * nn.tanh(z2)
o = nn.sigmoid(z3)
h = o * nn.tanh(c)
return h, [h, c]
从源码中可以看出,第一个和第二个输出是output/hidden state,第三个输出是cell state。并且从注释中可以看出,将hidden state 命名为 memory state ;将cell state 命名为 carry state。
当return_sequences=True
时,whole_seq_output是整个序列的输出,维度为(batch_size,seq_length,units)。
当return_sequences=False
时,whole_seq_output是最后一个单元的输出,维度为(batch_size,units),此时与第二个输出相同。