PyTorch-LSTM

 1 import torch
 2 import torch.nn as nn
 3 
 4 torch.random.manual_seed(10)
 5 
 6 input_size = 2  # 输入向量维度
 7 hidden_size = 4 # 隐层层维度
 8 num_layers = 2 # 层数
 9 
10 lstm = nn.LSTM(input_size, hidden_size, num_layers)
11 
12 
13 # Input:
14 
15 # input of shape (sep_len, bath, input_size)
16 # h_t-1 of shape (num_directions * num_layers, bath, hidden_size)
17 # c_t-1 for shape (num_directions * num_layers, bath, hidden_size)
18 
19 # Output:
20 # output of shape (sep_len, bath, num_directions * hidden_size)
21 # h_t-1 of shape (num_directions * num_layers, bath, hidden_size)
22 # c_t-1 for shape (num_directions * num_layers, bath, hidden_size)
23 
24 # two ways
25 Input = torch.randn(4, 3, 2)
26 h = torch.randn(2, 3, 4)
27 c = torch.randn(2, 3, 4)
28 output = None
29 
30 # first
31 h1 = h
32 c1 = c
33 for it in Input:
34     output, (h1, c1) = lstm(it.view(1, 3, -1), (h1, c1))
35     print((output == h1[-1]).all().item())
36 print(output)
37 
38 # second
39 output1, (h, c) = lstm(Input,(h, c))
40 print(output1[-1])
41 # print(output1[-1] == output) 精度的问题

你可能感兴趣的:(PyTorch-LSTM)