[NLP] Description and implementation of LSTM neural network.

0. Statement

In this post I want to move from an intuitive understanding of LSTM to its implementation in PyTorch, and I hope readers will find it genuinely helpful.

1. What is the advantage of LSTM over RNN?

An RNN can remember fewer context-related words than an LSTM, so the RNN acts as a short-term memory network, while the LSTM is a Long Short-Term Memory network: "long" means it can retain context over far more time steps than the plain RNN, which represents the short-term case. So you don't have to be confused by the name (Long Short-Term).

2. Differences in terminology between RNN and LSTM for picture representation.

The RNN and LSTM pictures are from here.
[Figure: RNN architecture]
[Figure: LSTM architecture]

  1. The output of the RNN is ot, while the output of the LSTM is ht.
  2. The contextual information (memory) of the RNN is stored in ht (above), while the contextual information (memory) of the LSTM is stored in ct.

3. The composition and intuitive understanding of LSTM.

When I first saw the architecture diagram of the LSTM, I noticed a recurring schematic of a sigmoid followed by an element-wise multiplication.
[Figure: sigmoid-and-multiply schematic]
The sigmoid outputs values between 0 and 1, so multiplying some quantity by its output lets the network decide how much of that quantity to keep: a value near 0 means discard it, a value near 1 means pass it through. The significance of each of these structures is to decide whether to use the data from the previous time step, and the structure is named a "gate".
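As a toy illustration (the numbers here are made up), a gate is nothing more than an element-wise multiplication by a sigmoid output:

import torch

# Toy example: a "gate" is just an element-wise multiplication by a sigmoid output.
memory = torch.tensor([0.5, -1.2, 3.0])                    # data remembered from the previous step
gate = torch.sigmoid(torch.tensor([-10.0, 0.0, 10.0]))     # ≈ [0.0, 0.5, 1.0]
print(gate * memory)                                       # ≈ [0.0, -0.6, 3.0]: forget, halve, keep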

3.1. LSTM: Forget gate.
[Figure: forget gate]

When the value of "ft" is close to 1, it means "use the data remembered at the previous time step"; when it is close to 0, it means "forget it".
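For reference, the standard forget-gate formula (the one drawn in the figure above) is:

ft = sigmoid(Wf · [h(t-1), x(t)] + bf)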

3.2. LSTM: Input gate and Cell State.
[Figure: input gate and candidate cell state]

When the value of "it" is close to 1, it means "use the data entered at the current time step", i.e. the candidate cell state C̃t ("C wave t"), which is computed with "tanh", "Wc" and "bc" from the current input "xt" and the previous hidden state "h(t-1)".
[Figure: cell state update]
⚠️: In summary, the forget gate decides whether the information remembered at the previous time step is still useful, and the input gate decides whether the information arriving at the current time step is important enough to remember.
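Written out in the standard form (using the same symbols as above), the input gate and the cell-state update are:

it = sigmoid(Wi · [h(t-1), x(t)] + bi)
C̃t = tanh(Wc · [h(t-1), x(t)] + bc)
c(t) = ft ⊙ c(t-1) + it ⊙ C̃t        (⊙ denotes the element-wise product)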

3.3. LSTM: Output.


In summary, the output "ht" of the LSTM is the element-wise product of the output gate "ot" and "tanh(ct)".
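In the same notation:

ot = sigmoid(Wo · [h(t-1), x(t)] + bo)
h(t) = ot ⊙ tanh(c(t))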

3.4. Summary

[Figure: LSTM gate equations]
The forget gate, input gate, and output gate each apply a sigmoid to the current input x(t) and the previous hidden state h(t-1). The candidate cell state C̃t fed into the cell is computed with tanh from x(t) and h(t-1). How c(t) and h(t) are then passed on to the next time step is easy to understand intuitively from the diagram above; a small code sketch of one full step follows below.
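To make the summary concrete, here is a minimal sketch of a single LSTM step written directly from the equations above (the function and weight names such as lstm_step and W_f are my own, not part of any library API):

import torch

def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_i, b_i, W_c, b_c, W_o, b_o):
    # Concatenate the previous hidden state and the current input, as in the diagrams above.
    z = torch.cat([h_prev, x_t], dim=-1)
    f_t = torch.sigmoid(z @ W_f + b_f)       # forget gate
    i_t = torch.sigmoid(z @ W_i + b_i)       # input gate
    c_tilde = torch.tanh(z @ W_c + b_c)      # candidate cell state ("C wave t")
    c_t = f_t * c_prev + i_t * c_tilde       # new cell state (the memory)
    o_t = torch.sigmoid(z @ W_o + b_o)       # output gate
    h_t = o_t * torch.tanh(c_t)              # new hidden state (the output)
    return h_t, c_t

With input_size=100 and hidden_size=20, each W_* here would have shape (120, 20) and each b_* shape (20,); nn.LSTM learns the same matrices for you, stacked internally into its weight_ih/weight_hh parameters.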


3.5. Understanding the role of "gates" intuitively.
[Figure: an intuitive view of the gates]

3.6. Why can LSTM mitigate gradient vanishing?

[Figure: gradient flow through the cell state]
Because when we compute the gradient along the cell-state path, we avoid the k-th power of "Whh" that appears in a plain RNN. In addition, since there are three "gates", the gradient expands into several terms rather than a single product, and the gates constrain one another, so the probability that the overall value explodes or vanishes is much smaller.
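A rough sketch of why (treating the gate values as constants for a moment, so this is intuition rather than a full derivation):

∂c(t)/∂c(t-1) ≈ ft
∂c(t)/∂c(t-k) ≈ ft ⊙ f(t-1) ⊙ … ⊙ f(t-k+1)

So the gradient flowing along the cell state is a product of learned forget-gate values rather than Whh raised to the k-th power; as long as the network keeps the relevant forget gates close to 1, this product need not vanish. The full derivative also contains terms through it, C̃t and ot, which is why the three gate equations all enter the expansion and constrain one another.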

4. How to implement LSTM with PyTorch?

[Figure: nn.LSTM in the PyTorch documentation]
The PyTorch documentation on LSTM is from here.

import torch
from torch import nn

# 4-layer LSTM: each time step takes a 100-dim input and produces a 20-dim hidden state.
lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
print(lstm)
# Input shape is (seq_len, batch, input_size) because batch_first defaults to False.
x = torch.randn(10, 3, 100)
out, (h, c) = lstm(x)
print(out[-1])   # top-layer hidden state at the last time step
print(h[-1])     # final hidden state of the last (top) layer
print(out[-1].shape)
print(h[-1].shape)
print(out[-1] == h[-1])   # all True: they contain the same values
print(f"out.shape:{out.shape}\nh.shape:{h.shape}\nc.shape:{c.shape}")
# out: (10, 3, 20) = (seq_len, batch, hidden_size)
# h, c: (4, 3, 20) = (num_layers, batch, hidden_size)

⚠️: The "out" of the LSTM stacks the top layer's hidden state h for every time step, so out[-1] (the last time step) equals h[-1] (the final hidden state of the top layer).
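A quick check of this (assuming the snippet above has just been run):

assert torch.equal(out[-1], h[-1])   # last-time-step output == top layer's final hidden state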
[Screenshot: console output of the snippet above]


import torch
from torch import nn

x = torch.randn(10, 3, 100)   # (seq_len, batch, input_size), same shape as above

print('one layer lstm')
# nn.LSTMCell processes a single time step, so we loop over the sequence ourselves.
cell = nn.LSTMCell(input_size=100, hidden_size=20)
h = torch.zeros(3, 20)   # (batch, hidden_size)
c = torch.zeros(3, 20)
for xt in x:
    h, c = cell(xt, (h, c))
print(f"h.shape:{h.shape}")
print(f"c.shape:{c.shape}")

print("two layer lstm")
# Stack two cells manually: the first cell's hidden state feeds the second cell.
cell1 = nn.LSTMCell(input_size=100, hidden_size=30)
cell2 = nn.LSTMCell(input_size=30, hidden_size=20)
h1 = torch.zeros(3, 30)
c1 = torch.zeros(3, 30)
h2 = torch.zeros(3, 20)
c2 = torch.zeros(3, 20)
for xt in x:
    h1, c1 = cell1(xt, (h1, c1))
    h2, c2 = cell2(h1, (h2, c2))
print(f"h2.shape:{h2.shape}")
print(f"c2.shape:{c2.shape}")

[Screenshot: console output of the LSTMCell snippets above]

Finally

Thank you to this age of knowledge sharing and to the people willing to share; the knowledge in this blog is what I have learned from this site, thanks for your support!
