LSTM 简介
- 官方文档:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
- h_n:最后一个时间步的输出,即 h_n = output[:, -1, :](一般可以直接输入到后续的全连接层,在 Keras 中通过设置参数 return_sequences=False 获得)
- c_n:最后一个时间步 LSTM cell 的状态(一般用不到)
实例
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(input_size=2, hidden_size=3, batch_first=True)
>>> input = torch.randn(5,4,2)
>>> h0 = torch.randn(1, 5, 3)
>>> c0 = torch.randn(1, 5, 3)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[-0.1046, -0.0316, -0.2261],
[ 0.0702, 0.0756, -0.2856],
[ 0.1146, 0.0666, -0.1841],
[ 0.1137, 0.0508, -0.3966]],
[[ 0.3702, -0.1192, -0.3513],
[ 0.3964, -0.0513, -0.1744],
[ 0.3144, 0.0564, -0.2114],
[ 0.3056, 0.1312, -0.1656]],
[[ 0.1581, -0.3509, 0.0068],
[ 0.2391, -0.0308, 0.0773],
[ 0.2420, 0.0607, -0.0652],
[ 0.2854, 0.0656, -0.0306]],
[[-0.0562, -0.0229, 0.1600],
[-0.2156, -0.0006, 0.0898],
[ 0.0700, 0.2200, -0.0068],
[ 0.1903, 0.3120, 0.0253]],
[[ 0.1025, -0.0167, 0.3068],
[ 0.2028, 0.0652, 0.1738],
[ 0.3324, 0.1645, 0.1908],
[ 0.2594, 0.0896, -0.0507]]], grad_fn=)
>>> hn
tensor([[[ 0.1137, 0.0508, -0.3966],
[ 0.3056, 0.1312, -0.1656],
[ 0.2854, 0.0656, -0.0306],
[ 0.1903, 0.3120, 0.0253],
[ 0.2594, 0.0896, -0.0507]]], grad_fn=)
>>> cn
tensor([[[ 0.3811, 0.2079, -0.7427],
[ 0.9059, 0.2375, -0.3272],
[ 0.5819, 0.1175, -0.0766],
[ 0.5059, 0.5022, 0.0446],
[ 0.7312, 0.2270, -0.0970]]], grad_fn=)
>>> output[-1]
tensor([[ 0.1025, -0.0167, 0.3068],
[ 0.2028, 0.0652, 0.1738],
[ 0.3324, 0.1645, 0.1908],
[ 0.2594, 0.0896, -0.0507]], grad_fn=)
>>> output[:,:,-1]
tensor([[-0.2261, -0.2856, -0.1841, -0.3966],
[-0.3513, -0.1744, -0.2114, -0.1656],
[ 0.0068, 0.0773, -0.0652, -0.0306],
[ 0.1600, 0.0898, -0.0068, 0.0253],
[ 0.3068, 0.1738, 0.1908, -0.0507]], grad_fn=)
>>> output[:,-1,:]
tensor([[ 0.1137, 0.0508, -0.3966],
[ 0.3056, 0.1312, -0.1656],
[ 0.2854, 0.0656, -0.0306],
[ 0.1903, 0.3120, 0.0253],
[ 0.2594, 0.0896, -0.0507]], grad_fn=)
>>> output[:,-1,:].shape
torch.Size([5, 3])
>>> output.shape
torch.Size([5, 4, 3])
>>> hn.shape
torch.Size([1, 5, 3])
>>> cn.shape
torch.Size([1, 5, 3])