pytorch——LSTM原理与实现

文章目录

  • RNN训练难题
    • 梯度爆炸
    • 梯度弥散
  • LSTM
    • 遗忘门
    • 输入门
    • 输出门
    • LSTM 总结
  • Pytorch实现LSTM的方法
    • nn.LSTM()
    • nn.LSTMCell()

RNN训练难题

RNN的梯度推导公式:
pytorch——LSTM原理与实现_第1张图片
累乘会导致的梯度爆炸或梯度弥散。

梯度爆炸

现象:比如loss从0.25、0.24突然变的很大,比如1.7、2.3。

解决方案:对梯度做clipping(保持梯度的方向,将梯度的模变小)。

pytorch——LSTM原理与实现_第2张图片
pytorch——LSTM原理与实现_第3张图片
将gradient的模clipping到0-10的范围内,之后再做optimizer.step()效果就会好很多。

梯度弥散

pytorch——LSTM原理与实现_第4张图片
反向传播时越靠前的神经层更新越小,前面的神经层的梯度会接近于0,得到的更新会非常小。

解决梯度弥散:LSTM

LSTM

相比于RNN,LSTM可以记住更长时间的语境。

记忆Ct-1经过乘运算后的范围:0 ~ Ct-1。

遗忘门

pytorch——LSTM原理与实现_第5张图片
f(t)由h(t-1)和x(t)决定,控制着t时刻之前信息的保留量。

输入门

pytorch——LSTM原理与实现_第6张图片
i(t)是门的开度,表示当前信息保留多少与过去的信息融合。
新的信息不是x(t),而是x(t)运算后得到的C~(t)。

pytorch——LSTM原理与实现_第7张图片
C(t)是新的记忆。

输出门

pytorch——LSTM原理与实现_第8张图片
h(t)是输出。
o(t)表示输出门的开度,范围0-1。当前记忆C(t)不一定全部输出,C(t)经过tanh,与o(t)相乘后,可以有选择地输出。

LSTM 总结

三个门的开度都是由h(t-1)和X(t)控制的。
pytorch——LSTM原理与实现_第9张图片
pytorch——LSTM原理与实现_第10张图片

LSTM如何解决梯度弥散:
LSTM避免了梯度的累乘,变成了四项累加。
pytorch——LSTM原理与实现_第11张图片

Pytorch实现LSTM的方法

nn.LSTM()

pytorch——LSTM原理与实现_第12张图片pytorch——LSTM原理与实现_第13张图片

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
print(lstm)  # LSTM(100, 20, num_layers=4)

x = torch.randn(10,3,100)
# x.shape: [seq_len, batch_size, input_size]

out, (h, c) = lstm(x)
# out.shape: [seq_len, batch_size, hidden_size]
# h.shape: [num_layers, batch_size, hidden_size]
# c.shape: [num_layers, batch_size, hidden_size]

print(out.shape)  # torch.Size([10, 3, 20])
print(h.shape)  # torch.Size([4, 3, 20])
print(c.shape)  # torch.Size([4, 3, 20])

nn.LSTMCell()

pytorch——LSTM原理与实现_第14张图片

pytorch——LSTM原理与实现_第15张图片

import torch
import torch.nn as nn

cell = nn.LSTMCell(input_size=100, hidden_size=20)

x = torch.randn(10,3,100)
h = torch.zeros(3, 20)
c = torch.zeros(3, 20)

for xt in x:
    h, c = cell(xt, [h, c])
    # xt: [batch_size, input_size]
    # h: [batch_size, hidden_size]
    # c: [batch_size, hidden_size]
    
print(h.shape)  # torch.Size([3, 20])
print(c.shape)  # torch.Size([3, 20])

你可能感兴趣的:(机器学习)