LSTM参数详解(其余RNN类似)

输入数据 input: (seq_len, batch_size, input_size)
LSTM(input_size, hidden_size, num_layers = 1, bidirectional = False)
其中在时间步 t 的hidden_state ht 和cell_state ct 的shape均为
(num_layers * num_direction, batch_size,hidden_size)
输出向量 output: (seq_leng, batch_size, num_directions * hidden_size)

调用方法 output,(hn,cn) = lstm(input,(h0,c0)) #h0,c0如果省略即为0向量

这里有一点需要说明

If num_layers = num_directions = 1:

output.size() == (seq_len, batch_size, hidden_size)
hn.size() == (1, batch_size, hidden_size)
hn 就是 output 的seq_len维度最后一个index的元素。

If num_layers = 1 && num_directions = 2:

output.size() == (seq_len, batch_size, 2 * hidden_size)
hn.size() == (2, batch_size, hidden_size)
那么这时 output在seq_len维度最后一个index的元素其实就是 hn[0]hn[1]的concatenation
其中 hn[0]LSTM从左向右编码句子的最后一个hidden_state,对应最后一个token;
然而 hn[1]LSTM从右向左编码句子的最后一个hidden_state,对应第一个token。
output 如果按照第三维度均分为两份就可以得到 output_forwardoutput_backward
他们的size() 都 == (seq_len, batch_size, hidden_size)
其中
output_forward[-1] == hn[0] 也就是从左向右编码对应最后一个token的hidden_state ;
output_backward[0] == hn[1] 也就是从右向左编码对应第一个token的hidden_state ;

你可能感兴趣的:(LSTM参数详解(其余RNN类似))