torch.nn.LSTM详细解读代码

小白代码解读

import torch.nn as nn
import torch
from torch.autograd import Variable
lstm =  nn.LSTM(10,20,2) # (1)输入的特征维度10列 (2)隐状态的特征维度20列 (3)num_layers = 2层
# print(lstm)
# print("***************************************************")
# 输入
input = Variable(torch.randn(5,3,10)) # 5行矩阵 每个矩阵是3行10列;10列是根据LSTM输入规定10列
#print(input)
# print("++++++++++++++++++++++++++++++++++++++++++++++++++")
# 保存着batch中每个元素的初始化隐状态的Tensor
h0 = Variable(torch.randn(2,3,20)) # 2行矩阵,每个矩阵是3行20列;2是LSTM中2层规定,20列是LSTM输入规定20列
# 保存着batch中每个元素的初始化细胞状态的Tensor
c0 = Variable(torch.randn(2,3,20))# 2行矩阵,每个矩阵是3行20列;2是LSTM中2层规定,20列是LSTM输入规定20列
# 输出
output , hn = lstm(input,(h0,c0))
# print(output)
# print("--------------------------------------------------")
print(hn)

你可能感兴趣的:(pytorch,pytorch)