For each element in the input sequence, each layer computes the following function:
$$ i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) $$
$$ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) $$
$$ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) $$
$$ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) $$
$$ c_t = f_t \odot c_{t-1} + i_t \odot g_t $$
$$ h_t = o_t \odot \tanh(c_t) $$
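The variables have the following meanings: $x_t$ is the input at time $t$; $h_t$ is the hidden state at time $t$; $c_t$ is the cell state at time $t$; $h_{t-1}$ is the hidden state of the layer at time $t-1$, or the initial hidden state at time 0; $i_t$, $f_t$, $g_t$, and $o_t$ are the input, forget, cell, and output gates; $\sigma$ is the sigmoid function; and $\odot$ is the Hadamard (element-wise) product.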
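The update equations can be checked numerically against `nn.LSTM`. The sketch below is a minimal illustration, not the library's internal implementation: the helper `lstm_step` and the tensor names are made up for this example, and it assumes a single-layer, unidirectional LSTM. It relies on the fact that PyTorch stores the four gate weights stacked in the order input, forget, cell, output in `weight_ih_l0` and `weight_hh_l0`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 10, 20, 3
lstm = nn.LSTM(input_size, hidden_size, num_layers=1)

x_t = torch.randn(batch_size, input_size)      # input at one time step
h_prev = torch.randn(batch_size, hidden_size)  # h_{t-1}
c_prev = torch.randn(batch_size, hidden_size)  # c_{t-1}


def lstm_step(x_t, h_prev, c_prev, lstm):
    """Hypothetical helper: one LSTM time step written out from the equations."""
    # All four pre-activations at once; rows are stacked as (i, f, g, o).
    gates = (x_t @ lstm.weight_ih_l0.T + lstm.bias_ih_l0
             + h_prev @ lstm.weight_hh_l0.T + lstm.bias_hh_l0)
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_t = f * c_prev + i * g        # c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
    h_t = o * torch.tanh(c_t)       # h_t = o_t ⊙ tanh(c_t)
    return h_t, c_t


with torch.no_grad():
    h_manual, c_manual = lstm_step(x_t, h_prev, c_prev, lstm)
    # Reference: run the real module on a sequence of length 1.
    _, (h_ref, c_ref) = lstm(x_t.unsqueeze(0),
                             (h_prev.unsqueeze(0), c_prev.unsqueeze(0)))

h_matches = torch.allclose(h_manual, h_ref[0], atol=1e-5)
c_matches = torch.allclose(c_manual, c_ref[0], atol=1e-5)
```

If both flags are true, the hand-written step and the module agree up to floating-point tolerance.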
import torch.nn as nn
import torch
rnn = nn.LSTM(10, 20, 2)  # input_size, hidden_size, num_layers
input = torch.randn(5, 3, 10)  # sequence length, batch size, input_size
h0 = torch.randn(2, 3, 20)  # num_layers * num_directions, batch size, hidden_size
c0 = torch.randn(2, 3, 20)  # num_layers * num_directions, batch size, hidden_size
output, (hn, cn) = rnn(input, (h0, c0))
output.shape
Out[8]: torch.Size([5, 3, 20])  # sequence length, batch size, hidden_size
hn.shape
Out[9]: torch.Size([2, 3, 20])  # num_layers * num_directions, batch size, hidden_size
cn.shape
Out[10]: torch.Size([2, 3, 20])  # num_layers * num_directions, batch size, hidden_size
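To make the shape comments concrete, here is a small sketch (variable names such as `bi_rnn` are illustrative only) showing two things: for a unidirectional LSTM, the last time step of `output` is the final hidden state of the top layer, and for a bidirectional LSTM the direction count shows up in both the last dimension of `output` and the first dimension of `hn`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(5, 3, 10)  # sequence length, batch size, input_size

# Unidirectional, 2 layers: num_directions = 1
rnn = nn.LSTM(10, 20, 2)
output, (hn, cn) = rnn(x)  # h0 and c0 default to zeros when omitted
# output holds the top layer's hidden state at every step,
# so its last step should equal hn for the last layer.
last_step_matches = torch.allclose(output[-1], hn[-1])

# Bidirectional, 2 layers: num_directions = 2
bi_rnn = nn.LSTM(10, 20, 2, bidirectional=True)
bi_output, (bi_hn, bi_cn) = bi_rnn(x)
# bi_output: (seq_len, batch, 2 * hidden_size)
# bi_hn:     (num_layers * 2, batch, hidden_size)
```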
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
rnn = nn.LSTM(input_size=1, hidden_size=20, num_layers=2)
input = torch.tensor([[1,2,0], [3,0,0], [4,5,6]], dtype=torch.float)
lens = [2, 1, 3]
# Build the input data; after unsqueeze its shape is torch.Size([3, 3, 1]), i.e. batch_size=3, sequence length=3, input_size=1
input = input.unsqueeze(2)
input
Out[68]:
tensor([[[1.],
         [2.],
         [0.]],

        [[3.],
         [0.],
         [0.]],

        [[4.],
         [5.],
         [6.]]])
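# The first dimension is the batch, so batch_first=True; enforce_sorted=False allows sequences in any length order
packed_seq = pack_padded_sequence(input, lens, batch_first=True, enforce_sorted=False)
# Feed packed_seq to the LSTM without supplying initial hidden and cell states (they default to zeros)
output, (hn, cn) = rnn(packed_seq)
# Reverse the packing: pad the packed output back into a regular tensor
output = pad_packed_sequence(output, batch_first=True)
# output[0] is the LSTM output; output[1] holds the length of each sample in the batch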
output[0].shape
Out[72]: torch.Size([3, 3, 20])
output[1]
Out[73]: tensor([2, 1, 3])
hn.shape
Out[76]: torch.Size([2, 3, 20])
cn.shape
Out[77]: torch.Size([2, 3, 20])
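The effect of packing can be inspected directly. The following sketch repeats the setup above with illustrative names (`packed`, `last_valid`, and so on are not from the original) and checks three things: how many sequences are still active at each time step, that the padded positions of the unpacked output are zero, and that `hn` for the top layer matches the output at each sequence's last valid step rather than at the padded end.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
rnn = nn.LSTM(input_size=1, hidden_size=20, num_layers=2)
x = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]], dtype=torch.float).unsqueeze(2)
lens = torch.tensor([2, 1, 3])  # true lengths, must live on the CPU

packed = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
# batch_sizes[t] = number of sequences that still have an element at step t
batch_sizes = packed.batch_sizes.tolist()

packed_out, (hn, cn) = rnn(packed)
out, out_lens = pad_packed_sequence(packed_out, batch_first=True)

# Positions past each sequence's length are filled with the padding value 0.
padding_is_zero = bool((out[0, 2:] == 0).all() and (out[1, 1:] == 0).all())

# Pick the output at the last valid step of every sequence: shape (3, 20).
last_valid = out[torch.arange(len(lens)), lens - 1]
# The LSTM stops at each sequence's true end, so this should equal hn[-1].
matches_hn = torch.allclose(last_valid, hn[-1])
```

Without packing, the LSTM would keep stepping over the zero padding, and `hn` would reflect the state after the padded steps instead.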
To be continued...