PyTorch 中 LSTM 的 output、h_n 和 c_n 之间的关系

LSTM 简介

  • 官方文档:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
  • h_n:最后一个时间步的输出,即 h_n = output[:, -1, :](一般可以直接输入到后续的全连接层,在 Keras 中通过设置参数 return_sequences=False 获得)
  • c_n:最后一个时间步 LSTM cell 的状态(一般用不到)

实例

  • 实例:根据红框可以直观看出,h_n 是最后一个时间步的输出,即是 h_n = output[:, -1, :],如何还是无法直观理解,直接看如下截图,对照代码可以非常容易看出它们的关系
    PyTorch 中 LSTM 的 output、h_n 和 c_n 之间的关系_第1张图片

  • 实例代码:

>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(input_size=2, hidden_size=3, batch_first=True)
>>> input = torch.randn(5,4,2)
>>> h0 = torch.randn(1, 5, 3)
>>> c0 = torch.randn(1, 5, 3)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[-0.1046, -0.0316, -0.2261],
         [ 0.0702,  0.0756, -0.2856],
         [ 0.1146,  0.0666, -0.1841],
         [ 0.1137,  0.0508, -0.3966]],

        [[ 0.3702, -0.1192, -0.3513],
         [ 0.3964, -0.0513, -0.1744],
         [ 0.3144,  0.0564, -0.2114],
         [ 0.3056,  0.1312, -0.1656]],

        [[ 0.1581, -0.3509,  0.0068],
         [ 0.2391, -0.0308,  0.0773],
         [ 0.2420,  0.0607, -0.0652],
         [ 0.2854,  0.0656, -0.0306]],

        [[-0.0562, -0.0229,  0.1600],
         [-0.2156, -0.0006,  0.0898],
         [ 0.0700,  0.2200, -0.0068],
         [ 0.1903,  0.3120,  0.0253]],

        [[ 0.1025, -0.0167,  0.3068],
         [ 0.2028,  0.0652,  0.1738],
         [ 0.3324,  0.1645,  0.1908],
         [ 0.2594,  0.0896, -0.0507]]], grad_fn=)
>>> hn
tensor([[[ 0.1137,  0.0508, -0.3966],
         [ 0.3056,  0.1312, -0.1656],
         [ 0.2854,  0.0656, -0.0306],
         [ 0.1903,  0.3120,  0.0253],
         [ 0.2594,  0.0896, -0.0507]]], grad_fn=)
>>> cn
tensor([[[ 0.3811,  0.2079, -0.7427],
         [ 0.9059,  0.2375, -0.3272],
         [ 0.5819,  0.1175, -0.0766],
         [ 0.5059,  0.5022,  0.0446],
         [ 0.7312,  0.2270, -0.0970]]], grad_fn=)
>>> output[-1]
tensor([[ 0.1025, -0.0167,  0.3068],
        [ 0.2028,  0.0652,  0.1738],
        [ 0.3324,  0.1645,  0.1908],
        [ 0.2594,  0.0896, -0.0507]], grad_fn=)
>>> output[:,:,-1]
tensor([[-0.2261, -0.2856, -0.1841, -0.3966],
        [-0.3513, -0.1744, -0.2114, -0.1656],
        [ 0.0068,  0.0773, -0.0652, -0.0306],
        [ 0.1600,  0.0898, -0.0068,  0.0253],
        [ 0.3068,  0.1738,  0.1908, -0.0507]], grad_fn=)
>>> output[:,-1,:]
tensor([[ 0.1137,  0.0508, -0.3966],
        [ 0.3056,  0.1312, -0.1656],
        [ 0.2854,  0.0656, -0.0306],
        [ 0.1903,  0.3120,  0.0253],
        [ 0.2594,  0.0896, -0.0507]], grad_fn=)
>>> output[:,-1,:].shape
torch.Size([5, 3])
>>> output.shape
torch.Size([5, 4, 3])
>>> hn.shape
torch.Size([1, 5, 3])
>>> cn.shape
torch.Size([1, 5, 3])

你可能感兴趣的:(PyTorch,基础)