小白学Pytorch系列--Torch.nn API Recurrent Layers(8)

小白学Pytorch系列–Torch.nn API Recurrent Layers(8)

方法 注释
nn.RNNBase
nn.RNN 将具有tanh tanh或ReLU ReLU非线性的多层Elman RNN应用于输入序列。
nn.LSTM 将多层长短期记忆(LSTM) RNN应用于输入序列。
nn.GRU 将多层门控循环单元(GRU) RNN应用于输入序列。
nn.RNNCell 具有tanh或ReLU非线性的Elman RNN单元。
nn.LSTMCell 长短期记忆(LSTM)细胞。
nn.GRUCell 门控循环单元(GRU)细胞

nn.RNNBase

重置参数数据指针,以便它们可以使用更快的代码路径。

目前,只有当模块在GPU上并且启用了cuDNN时,这才有效。否则,这就是拒绝。
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第1张图片

nn.RNNCell

h ′ = tanh ⁡ ( W i h x + b i h + W h h h + b h h ) h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) h=tanh(Wihx+bih+Whhh+bhh)
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第2张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第3张图片

rnn = nn.RNNCell(10, 20)
input = torch.randn(6, 3, 10)
hx = torch.randn(3, 20)
output = []
for i in range(6):
    hx = rnn(input[i], hx)
    output.append(hx)

nn.RNN

将具有tanh或ReLU非线性的多层Elman RNN应用于输入序列。对于输入序列中的每个元素,每个层计算以下函数:
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第4张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第5张图片

import torch.nn as nn
import torch
model = nn.RNN(input_size=10, hidden_size=100, batch_first=True, num_layers=2, bidirectional =True)

input_tensor = torch.randn(2, 5, 10 )
output, hidden = model(input_tensor)
print(output.shape) # [bz, seq_len, hz]
print(hidden.shape) #  [num_layer*D, bz, hz]

nn.LSTMCell

长短期记忆(LSTM) Cell
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第6张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第7张图片

>>> rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
>>> input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
>>> hx = torch.randn(3, 20)  # (batch, hidden_size)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(input.size()[0]):
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)
>>> output = torch.stack(output, dim=0)

nn.LSTM

将多层长短期记忆(LSTM) RNN应用于输入序列。
对于输入序列中的每个元素,每一层都计算以下函数
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第8张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第9张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第10张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第11张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第12张图片

>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0)) # [seq_len, bz, bi*hz]


nn.GRUCell

小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第13张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第14张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第15张图片

>>> rnn = nn.GRUCell(10, 20)
>>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
...     hx = rnn(input[i], hx)
...     output.append(hx)

nn.GRU

小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第16张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第17张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第18张图片
小白学Pytorch系列--Torch.nn API Recurrent Layers(8)_第19张图片

>>> rnn = nn.GRU(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> output, hn = rnn(input, h0)

你可能感兴趣的:(PyTorch框架,pytorch,深度学习,tensorflow)