今天在回顾Seq2Seq利用Attention注意力机制实现的时候,发现decoder中用到的不是普通的LSTM而是LSTMCell,那么它到底什怎么回事?和LSTM又有哪些区别呢?以及在Seq2Seq中起到了什么作用?让我们一探究竟!
如图是一个RNN按时间步的展开图,RNNCell就相当于一个时间步的处理。
同理,LSTMCell是LSTM的一个单元,LSTMCell就相当于一个时间步的处理。
class LSTMCell(RNNCellBase):
def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
if hx is None:
zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
hx = (zeros, zeros)
return _VF.lstm_cell(
input, hx,
self.weight_ih, self.weight_hh,
self.bias_ih, self.bias_hh,
)
和LSTM相比,LSTMCell参数中没有num_layers(层数)、bidirectional(双向)、dropout选项。
官方文档还提供一个用来同等实现LSTMCell作用的例子
Examples::
>>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
>>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
>>> hx = torch.randn(3, 20) # (batch, hidden_size)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(input.size()[0]):
hx, cx = rnn(input[i], (hx, cx))
output.append(hx)
>>> output = torch.stack(output, dim=0)
可能类的源代码我们还一知半解,不太清楚到底和LSTM有什么区别,下面的这个例子就很好地说明了一切。
其实它就是LSTM操作,只不过每一次执行完LSTMCell后,我们都执行了一步,而有多少个时间步,我们就需要执行多少个LSTMCell。
未完待续~