torch.nn.RNN(input_size, hidden_size, num_layers)函数解析

torch.nn.RNN(input_size, hidden_size, num_layers)

pytorch官方文档链接:https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN

input_size:每个token作为输入时的向量长度
hidden_size:中间的隐层向量长度
num_layers:RNN模型的层数

以下对于batch_size=1举例

rnn = nn.RNN(10, 20, 2)
input = torch.randn(3, 10)
h0 = torch.randn(2, 20)
output, hn = rnn(input, h0)
# output.shape应该是(3,20);hn.shape应该是(2,20)

计算过程可根据下图理解。

官方文档中计算h_t的公式可根据手绘图中的“框1”理解。
在这里插入图片描述

你可能感兴趣的:(python,rnn)