torch.nn.LSTM makes it easy to build an LSTM. If you are not yet familiar with RNNs or LSTMs, these two articles are a good starting point:
RNN: https://blog.csdn.net/yizhishuixiong/article/details/105588233
LSTM: https://blog.csdn.net/yizhishuixiong/article/details/105572296
torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False)
LSTM input: input, (h_0, c_0). With the default batch_first=False, input has shape (seq_len, batch, input_size); h_0 and c_0 each have shape (num_layers * num_directions, batch, hidden_size).
LSTM output: output, (h_n, c_n). output has shape (seq_len, batch, num_directions * hidden_size); h_n and c_n have the same shapes as h_0 and c_0.
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 3)  # input_size=10 (word-vector length), hidden_size=20, num_layers=3
input = torch.randn(8, 3, 10)  # seq_len=8 (words per sentence), batch_size=3 (sentences), input_size=10
h_0, c_0 = torch.randn(3, 3, 20), torch.randn(3, 3, 20)  # each (num_layers, batch_size, hidden_size)
output, (h_n, c_n) = rnn(input, (h_0, c_0))

print("input.shape:", input.shape)
print("h_0.shape:", h_0.shape)
print("c_0.shape:", c_0.shape)
print("*" * 50)
print("output.shape:", output.shape)
print("h_n.shape:", h_n.shape)
print("c_n.shape:", c_n.shape)
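If you prefer to keep the batch as the first dimension, pass batch_first=True. A minimal sketch (same sizes as above): note that batch_first only swaps the layout of input and output; h_0 and c_0 keep the (num_layers, batch, hidden_size) layout.

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 3, batch_first=True)  # input/output become (batch, seq_len, feature)
input = torch.randn(3, 8, 10)               # 3 sentences, 8 words each, 10-dim word vectors
h_0, c_0 = torch.randn(3, 3, 20), torch.randn(3, 3, 20)  # unchanged: (num_layers, batch, hidden)
output, (h_n, c_n) = rnn(input, (h_0, c_0))

print("output.shape:", output.shape)  # torch.Size([3, 8, 20])
print("h_n.shape:", h_n.shape)        # torch.Size([3, 3, 20])
```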
Bidirectional:
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 3, bidirectional=True)  # input_size=10, hidden_size=20, num_layers=3, bidirectional
input = torch.randn(8, 3, 10)  # seq_len=8 (words per sentence), batch_size=3 (sentences), input_size=10
h_0, c_0 = torch.randn(6, 3, 20), torch.randn(6, 3, 20)  # each (num_layers * 2, batch_size, hidden_size)
output, (h_n, c_n) = rnn(input, (h_0, c_0))

print("input.shape:", input.shape)
print("h_0.shape:", h_0.shape)
print("c_0.shape:", c_0.shape)
print("*" * 50)
print("output.shape:", output.shape)
print("h_n.shape:", h_n.shape)
print("c_n.shape:", c_n.shape)
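In the bidirectional case, output concatenates the forward and backward directions along the last (feature) dimension, and h_n stacks them along the first dimension. A sketch of how to pull the two directions apart (same sizes as above; initial states are omitted here, in which case PyTorch uses zeros):

```python
import torch
import torch.nn as nn

num_layers, batch, hidden = 3, 3, 20
rnn = nn.LSTM(10, hidden, num_layers, bidirectional=True)
input = torch.randn(8, batch, 10)
output, (h_n, c_n) = rnn(input)  # (h_0, c_0) omitted -> zero initial states

# output: (seq_len, batch, 2 * hidden) -- split into forward/backward halves
fwd_out, bwd_out = output[..., :hidden], output[..., hidden:]

# h_n: (num_layers * 2, batch, hidden) -- view as (num_layers, num_directions, batch, hidden)
h_n = h_n.view(num_layers, 2, batch, hidden)
last_layer_fwd = h_n[-1, 0]  # final forward state of the top layer, (batch, hidden)
last_layer_bwd = h_n[-1, 1]  # final backward state of the top layer, (batch, hidden)
```

A useful sanity check: the forward half of output at the last timestep equals the top layer's forward final state, and the backward half at the first timestep equals the top layer's backward final state.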