LSTM:长短期记忆网络

LSTM:长短期记忆网络

  • 理论
    • 1.计算遗忘门
    • 2.计算输入门
    • 3.计算候选细胞状态
    • 4.更新细胞状态
    • 5.计算输出门
    • 6.计算输出隐状态
  • 实践
    • 从零实现LSTM
    • Pytorch实现LSTM
      • 参数
      • 输入
      • 输出

理论

LSTM:长短期记忆网络_第1张图片
LSTM:长短期记忆网络_第2张图片

LSTM的核心是细胞状态,用贯穿细胞的水平线表示
LSTM:长短期记忆网络_第3张图片

1.计算遗忘门

LSTM:长短期记忆网络_第4张图片
决定细胞状态需要舍弃哪部分无用信息

f t = σ g ( W f x t + U f h t − 1 + b f ) f_t = \sigma{_g} (W_f x_t+U_f h_{t-1}+b_f) ft=σg(Wfxt+Ufht1+bf)

2.计算输入门

LSTM:长短期记忆网络_第5张图片
决定细胞状态需要添加哪些有用信息

i t = σ g ( W i x t + U i h t − 1 + b i ) i_t = \sigma{_g}(W_i x_t+U_i h_{t-1}+b_i) it=σg(Wixt+Uiht1+bi)

3.计算候选细胞状态

c ~ t = σ c ( W c x t + U c h t − 1 + b c ) \widetilde{c}_t=\sigma{_c}(W_cx_t+U_ch_{t-1}+b_c) c t=σc(Wcxt+Ucht1+bc)

4.更新细胞状态

LSTM:长短期记忆网络_第6张图片

c t = f t ∘ c t − 1 + i t ∘ c ~ t c_t=f_t \circ c_{t-1}+i_t \circ \widetilde{c}_t ct=ftct1+itc t

5.计算输出门

控制细胞状态中哪些信息被输出
o t = σ g ( W o x t + U o h t − 1 + b o ) o_t=\sigma{_g}(W_ox_t+U_oh_{t-1}+b_o) ot=σg(Woxt+Uoht1+bo)

6.计算输出隐状态

h t = o t ∘ σ h ( c t ) h_t = o_t \circ \sigma{_h}(c_t) ht=otσh(ct)

实践

从零实现LSTM

class My_LSTM(nn. Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.gates = nn.Linear(input_size + hidden_size, hidden_size * 4)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn. Tanh()
        self.output = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, output_size)
        )
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, x):
        batch_size = x.size(0)
        seq_len = x.size(1)
        h, c = (torch.zeros(batch_size, self.hidden_size).to(x.device) for _ in range(2))
        y_list = []
        for i in range(seq_len):
            forget_gate, input_gate, output_gate, candidate_cell = \
                self.gates(torch.cat([x[:, i, :], h], dim=-1)).chunk(4, -1)
            forget_gate, input_gate, output_gate = (self.sigmoid(g)
                                                    for g in (forget_gate, input_gate, output_gate))
            c = forget_gate * c + input_gate * self.tanh(candidate_cell)
            h = output_gate * self.tanh(c)
            y_list.append(self.output(h))
        return torch.stack(y_list, dim=1), (h, c)

Pytorch实现LSTM

参数

LSTM:长短期记忆网络_第7张图片

输入

LSTM:长短期记忆网络_第8张图片

输出

LSTM:长短期记忆网络_第9张图片

lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=1,batch_first=True).to(device)

你可能感兴趣的:(感知,融合与预测,lstm,深度学习,人工智能)