torch.nn.GRU详解代码

小白撸代码

import torch.nn as nn
import torch
from torch.autograd import Variable
# 输入是中 输入10个特征维度 隐藏是20个特征维度(输入是10列   输出是20列)一共是 2层
rnn = nn.GRU(10,20,2)
# print(rnn,"#####################")

# 输入 一个矩阵中含有5个矩阵  每个矩阵中是3行10列 10列是GRU格式中的10列
input = Variable(torch.randn(5,3,10))
#print(input,"++++++++++++++++++++++++")

#保存着batch中每个元素的初始化隐状态的Tensor
# 其中2是GRU中 公有的2层 3行要跟输入中的3行要相等的行数
h0 = Variable(torch.randn(2,3,20))
# print(h0,"@@@@@@@@@@@@@@@@@@@@@@")

#output保存RNN最后一层的输出的Tensor。
#hn保存着RNN最后一个时间步的隐状态。
output , hn = rnn(input,h0)
print(output)

你可能感兴趣的:(pytorch,pytorch)