import tensorflow as tf
# TF 2.x executes eagerly by default; on TF 1.x you would call tf.enable_eager_execution()
# and use tf.truncated_normal instead of tf.random.truncated_normal.
# Dummy input: (batch=2, seq_len=3, feature_dim=4)
embedding = tf.Variable(tf.random.truncated_normal((2, 3, 4)))
lstm = tf.keras.layers.LSTM(units=5, return_sequences=False, return_state=False)
outputs = lstm(embedding) # return_sequences=False, return_state=False
print(outputs) # only the last time step's output of each sample, shape: (batch, hidden_units)
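# Shape check (a minimal sketch): with the dummy input above, batch=2 and units=5,
# so the single returned tensor has shape (2, 5).
print(outputs.shape)  # (2, 5)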
lstm = tf.keras.layers.LSTM(units=5, return_sequences=True, return_state=False)
outputs = lstm(embedding)
print(outputs) # (batch, seq_len, hidden_units)
lstm = tf.keras.layers.LSTM(units=5, return_sequences=True, return_state=True)
outputs, hidden, state = lstm(embedding)
"""
outputs: (batch, seq_len, hidden_units)
hidden: (batch, hidden_units), 是每个样本对应最后一个time step的输出, 这个输出对应着各自样本的最后一个time step
state: (batch, hidden_units), 是每个样本对应的记忆状态, 代表着cell state。
"""
print(outputs)
print(hidden)
print(state)
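# Sanity check (a sketch): with return_sequences=True and return_state=True, the
# returned hidden state h is exactly the last time step of the sequence output;
# the cell state c is internal and does not appear in the sequence output.
tf.debugging.assert_near(hidden, outputs[:, -1, :])
print(hidden.shape, state.shape)  # (2, 5) (2, 5)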
bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(5, return_sequences=False, return_state=False), merge_mode="concat")
outputs = bilstm(embedding)
print(outputs)
"""
outputs: (batch, hidden_units * 2); the last dimension depends on merge_mode
"""
bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(5, return_sequences=True, return_state=False), merge_mode="concat")
outputs = bilstm(embedding)
print(outputs)
"""
outputs: (batch, seq_len, hidden_units * 2); the last dimension depends on merge_mode
"""
print("====")
bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(2, return_sequences=True, return_state=True), merge_mode="concat")
outputs = bilstm(embedding)  # a list of 5 tensors: [output, forward_h, forward_c, backward_h, backward_c]
print(outputs)
"""
o, f_h, f_c, b_h, b_c = outputs
The return value has three parts:
1. the merged layer output o, shape (batch, seq_len, hidden_units * 2)
2. (h, c) of the forward LSTM, each of shape (batch, hidden_units)
3. (h, c) of the backward LSTM, each of shape (batch, hidden_units)
"""