pytorch学习笔记 —— torch.nn.LSTM

使用 torch.nn.LSTM 可以方便的构建 LSTM,不熟悉 LSTM 的可以先看这两篇文章:

RNN:https://blog.csdn.net/yizhishuixiong/article/details/105588233

LSTM:https://blog.csdn.net/yizhishuixiong/article/details/105572296


下面详细讲述 torch.nn.LSTM 的使用

torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False)

  • input_size:输入数据的大小;
  • hidden_size:隐藏层的大小(节点数量),输出向量的维度等于节点数量;
  • num_layers:recurrent layer 的数量(默认为1);
  • bias:默认为 True;
  • batch_first:输入输出维度的第一维是否为 batch_size。若为True,则 batch_size 在第一维,若为 False(默认),则 batch_size 在第二维;
  • dropout:若非0,则在除了最后一层的各层都使用 dropout 层,默认为0;
  • bidirectional:若为 True,则使用双向 LSTM,默认为 False;

LSTM 的输入:input,(h_0,c_0)

  • input:输入数据,shape 为(句子长度seq_len, 句子数量batch, 每个单词向量的长度input_size);
  • h_0:默认为0,shape 为(num_layers * num_directions单向为1双向为2, batch, 隐藏层节点数hidden_size);
  • c_0:默认为0,shape 为(num_layers * num_directions, batch, hidden_size)

LSTM 的输出:output,(h_n,c_n)

  • output:输出的 shape 为(seq_len, batch, num_directions * hidden_size);
  • h_n:shape 为(num_layers * num_directions, batch, hidden_size);
  • c_n:shape 为(num_layers * num_directions, batch, hidden_size);

代码演示

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 3)   # 一个单词向量长度为10,隐藏层节点数为20,LSTM数量为3
input = torch.randn(8, 3, 10)   # batch_size为3(输入数据有3个句子),每个句子有8个单词,每个单词向量长度为10
h_0, c_0 = torch.randn(3, 3, 20), torch.randn(3, 3, 20)
output, (h_n, c_n) = rnn(input, (h_0, c_0))

print("input.shape:", input.shape)
print("h_0.shape:", h_0.shape)
print("c_0.shape:", c_0.shape)
print("*" * 50)
print("output.shape:", output.shape)
print("h_n.shape:", h_n.shape)
print("c_n.shape:", c_n.shape)

pytorch学习笔记 —— torch.nn.LSTM_第1张图片

双向:

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 3, bidirectional=True)   # 一个单词向量长度为10,隐藏层节点数为20,LSTM数量为3,双向
input = torch.randn(8, 3, 10)   # batch_size为3(输入数据有3个句子),每个句子有8个单词,每个单词向量长度为10
h_0, c_0 = torch.randn(6, 3, 20), torch.randn(6, 3, 20)
output, (h_n, c_n) = rnn(input, (h_0, c_0))

print("input.shape:", input.shape)
print("h_0.shape:", h_0.shape)
print("c_0.shape:", c_0.shape)
print("*" * 50)
print("output.shape:", output.shape)
print("h_n.shape:", h_n.shape)
print("c_n.shape:", c_n.shape)

pytorch学习笔记 —— torch.nn.LSTM_第2张图片

你可能感兴趣的:(pytorch,自然语言处理)