LSTMCell

文章目录

  • 什么是LSTMCell
      • LSTMCell含义
      • LSTMCell类
      • 为什么要用LSTMCell

什么是LSTMCell

今天在回顾Seq2Seq利用Attention注意力机制实现的时候,发现decoder中用到的不是普通的LSTM而是LSTMCell,那么它到底什怎么回事?和LSTM又有哪些区别呢?以及在Seq2Seq中起到了什么作用?让我们一探究竟!

LSTMCell含义

LSTMCell_第1张图片

如图是一个RNN按时间步的展开图,RNNCell就相当于一个时间步的处理。

同理,LSTMCell是LSTM的一个单元,LSTMCell就相当于一个时间步的处理。

LSTMCell类

class LSTMCell(RNNCellBase):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)

    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        if hx is None:
            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        return _VF.lstm_cell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )

和LSTM相比,LSTMCell参数中没有num_layers(层数)、bidirectional(双向)、dropout选项。

官方文档还提供一个用来同等实现LSTMCell作用的例子

Examples::

        >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
        >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
        >>> hx = torch.randn(3, 20) # (batch, hidden_size)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(input.size()[0]):
                hx, cx = rnn(input[i], (hx, cx))
                output.append(hx)
        >>> output = torch.stack(output, dim=0)

可能类的源代码我们还一知半解,不太清楚到底和LSTM有什么区别,下面的这个例子就很好地说明了一切。

其实它就是LSTM操作,只不过每一次执行完LSTMCell后,我们都执行了一步,而有多少个时间步,我们就需要执行多少个LSTMCell。

为什么要用LSTMCell

未完待续~

你可能感兴趣的:(NLP模型学习,lstm,深度学习,python)