我在深入浅出RNN一文中提过,RNN缺乏对信息的调控机制,无法有效利用已学到的信息(hidden state)。RNN机械地将每次学到的新知(hidden state)都揉进同一个的hidden state并将它随着循环传递下去,这样做虽然可以一直保留句子中每个token的信息,但是,当这种雨露均沾模式,遇到长句时,句首信息在hidden state中的占比就会很小,换句话说,它很容易会忘掉长句的句首甚至句中的信息。不仅如此,每个time的训练都会受到历史信息的影响(hidden state + input)。
LSTM和GRU的出现就是为了弥补RNN的这些缺陷。本文将会以重构LSTM和GRU的方式来剖析LSTM和GRU,点击【这里】可以查看完整源码。
LSTM
正如Figure 1所示,LSTM通过引入input gate、forget gate和output gate来调控input和hidden state。这里的gate是由sigmoid函数实现的,它可以将任意input转换成0~1的值,将这些值和hidden state进行element-wise相乘运算,就起到了前文所说的信息调控作用:0表示丢弃信息,1表示完整地保留信息,0.x表示按比例保留信息。
- forget gate用于调控句子信息对训练的影响,它可以用来丢弃、筛选无用的信息。
- input gate负责调控hidden state转换成记忆(cell state)的比例。
- output gate则是用来调控传递到下一个time的hidden state的比例。
- cell state是LSTM新增的tensor,负责记忆信息,它可以被forget gate清除。
强烈推荐illustrated-guide-to-lstms-and-gru-s-a-step-by-step-explanation这篇博文,作者通过动图将各个gate的数据流清晰地呈现在读者面前,例如,forget gate:
nn.LSTM
class Model5(nn.Module):
def __init__(self):
super().__init__()
self.emb = nn.Embedding(nv, wordvec_len)
self.input = nn.Linear(wordvec_len, nh)
self.rnn = nn.LSTM(nh, nh, 1, batch_first=True)
self.out = nn.Linear(nh, nv)
self.bn = BatchNorm1dFlat(nh)
self.h = torch.zeros(1, bs, nh).cuda()
self.c = torch.zeros(1, bs, nh).cuda()
def forward(self, x):
res, (h, c) = self.rnn(self.input(self.emb(x)), (self.h, self.c))
self.h = h.detach()
self.c = c.detach()
return self.out(self.bn(res))
通过Pytorch提供的nn.LSTM,可以轻易地构建出基于LSTM的RNN。
def lstm_loop(cell, x, h):
hx, cx = [], []
h, c = h
for o in x.transpose(0, 1): # time loop
h, c = cell(o, (h, c))
hx.append(h)
cx.append(c)
# reset shape: [batch, time, hidden size]
return [torch.stack(hx, dim=1), torch.stack(cx, dim=1)]
class Model6(Model5):
def __init__(self):
super().__init__()
self.h = torch.zeros(bs, nh).cuda()
self.c = torch.zeros(bs, nh).cuda()
self.cell = nn.LSTMCell(nh, nh)
def forward(self, x):
x = F.relu(self.input(self.emb(x)))
h, c = lstm_loop(self.cell, x, (self.h, self.c))
self.h = h[:, -1].detach()
self.c = c[:, -1].detach()
return self.out(self.bn(h))
RNN的工作原理是循环调用隐藏层来处理每个input(回看深入浅出RNN的Model1)。在这里,lstm_loop是循环体,nn.LSTMCell则是隐藏层。
在验证了lstm_loop的正确性后,接着就是自己动手写nn.LSTMCell。首先,需要先拿到LSTM的数学公式,它们可以在Pytorch的nn.LSTMCell源代码的注释中找到。在Jupyter Notebook中查看函数或类的源代码的方法很简单,只要在模块前加上“??”:
??nn.LSTMCell
上述数学公式中的、和分别是input gate、forget gate和output gate,和分别是要输出到下一个time的cell state和hidden state,是sigmoid,则是linear(x)。
class LSTMCell(nn.Module):
def __init__(self, nin, nh):
super().__init__()
self.lin_x = nn.Linear(nin, 4 * nh)
self.lin_h = nn.Linear(nh, 4 * nh)
def forward(self, x, hc):
h, c = hc
_x = self.lin_x(x)
_h = self.lin_h(h)
x_i, x_f, x_o, x_g = _x.chunk(4, dim=1)
h_i, h_f, h_o, h_g = _h.chunk(4, dim=1)
i = torch.sigmoid(x_i + h_i)
f = torch.sigmoid(x_f + h_f)
o = torch.sigmoid(x_o + h_o)
g = torch.tanh(x_g + h_g)
c_hat = f * c + i * g
h_hat = o * torch.tanh(c_hat)
return (h_hat, c_hat)
在LSTMCell中,将和的hidden sizes乘以4:self.lin_x = nn.Linear(nin, 4 * nh),再把它们等分成4份:x_i, x_f, x_o, x_g = _x.chunk(4, dim=1),这样就能将8个linear layer计算合并成2个来提升计算速度。
GRU
GRU的工作原理和LSTM很相似,也通过各种gate来调控信息,如Figure 1所示,reset gate和forget gate的功能相同,update gate则与input gate功能类似。和LSTM不同的是,GRU放弃了cell state,相应地,也就不需要output gate来生成hidden state,所以GRU的所需的计算量相比LSTM减少了1/4。模型效果类似,但计算得更快,这些特性让GRU越来越受到推崇。
按照重写LSTMCell的方法,很容易也可以重写GRUCell:
class GRUCell(nn.Module):
def __init__(self, nin, nh):
super().__init__()
self.lin_x = nn.Linear(nin, 3 * nh)
self.lin_h = nn.Linear(nh, 3 * nh)
def forward(self, x, h):
_x = self.lin_x(x)
_h = self.lin_h(h)
ir, iz, xin = _x.chunk(3, dim=1)
hr, hz, hn = _h.chunk(3, dim=1)
r = torch.sigmoid(ir + hr) # reset gate
z = torch.sigmoid(iz + hz) # update gate
n = torch.tanh(xin + r * hn) # new gate
h_hat = (1 - z) * n + (z * h)
return h_hat
END
本文通过重构LSTM和GRU的方式详解了这两个模型的工作原理,其中的关键就是理解它们的数学公式。很多人一谈到数学公式,就本能地排斥,希望通过文字、图示这些直觉上感觉更直观的方式来辅助编程,殊不知,很多时候根据数学公式来编程反而更简单。