6.7 门控循环单元(GRU)
上一节介绍了循环神经网络中的梯度计算方法。我们发现,当时间步数较大或者时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。虽然裁剪梯度可以应对梯度爆炸,但无法解决梯度衰减的问题。通常由于这个原因,循环神经网络在实际中较难捕捉时间序列中时间步距离较大的依赖关系。
门控循环神经网络(gated recurrent neural network)的提出,正是为了更好地捕捉时间序列中时间步距离较大的依赖关系。它通过可以学习的门来控制信息的流动。其中,门控循环单元(gated recurrent unit,GRU)是一种常用的门控循环神经网络 [1, 2]。另一种常用的门控循环神经网络则将在下一节中介绍。
6.7.1 门控循环单元
下面将介绍门控循环单元的设计。它引入了重置门(reset gate)和更新门(update gate)的概念,从而修改了循环神经网络中隐藏状态的计算方式。
6.7.1.1 重置门和更新门
如图6.4所示,门控循环单元中的重置门和更新门的输入均为当前时间步输入X t Xt \boldsymbol{X}_t H t − 1 ,输出由激活函数为sigmoid函数的全连接层计算得到。
图6.4 门控循环单元中重置门和更新门的计算
具体来说,假设隐藏单元个数为h h h [ − 1 , 1 ] 。
图6.5 门控循环单元中候选隐藏状态的计算
具体来说,时间步t t t H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t .
图6.6 门控循环单元中隐藏状态的计算
值得注意的是,更新门可以控制隐藏状态应该如何被包含当前时间步信息的候选隐藏状态所更新,如图6.6所示。假设更新门在时间步t ′ t′ t' t 。这个设计可以应对循环神经网络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较大的依赖关系。
我们对门控循环单元的设计稍作总结:
重置门有助于捕捉时间序列里短期的依赖关系;
更新门有助于捕捉时间序列里长期的依赖关系。
6.7.2 读取数据集
为了实现并展示门控循环单元,下面依然使用周杰伦歌词数据集来训练模型作词。这里除门控循环单元以外的实现已在6.2节(循环神经网络)中介绍过。以下为读取数据集部分。
import numpy as np
import torch
from torch import nn, optim
import torch. nn. functional as F
import sys sys. path. append( "…" ) import d2lzh_pytorch as d2l device = torch. device( ‘cuda’ if torch. cuda. is_available( ) else ‘cpu’ )
( corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l. load_data_jay_lyrics( ) Copy to clipboard Error Copied
6.7.3 从零开始实现
我们先介绍如何从零开始实现门控循环单元。
6.7.3.1 初始化模型参数
下面的代码对模型参数进行初始化。超参数num_hiddens
定义了隐藏单元的个数。
num_inputs, num_hiddens, num_outputs = vocab_size, 256 , vocab_size
print ( ‘will use’ , device)
def get_params ( ) : def _one ( shape) : ts = torch. tensor( np. random. normal( 0 , 0.01 , size= shape) , device= device, dtype= torch. float32) return torch. nn. Parameter( ts, requires_grad= True ) def _three ( ) : return ( _one( ( num_inputs, num_hiddens) ) , _one( ( num_hiddens, num_hiddens) ) , torch. nn. Parameter( torch. zeros( num_hiddens, device= device, dtype= torch. float32) , requires_grad= True ) )
W_xz, W_hz, b_z = _three( )
W_xr, W_hr, b_r = _three( )
W_xh, W_hh, b_h = _three( )
W_hq = _one( ( num_hiddens, num_outputs) )
b_q = torch. nn. Parameter( torch. zeros( num_outputs, device= device, dtype= torch. float32) , requires_grad= True )
return nn. ParameterList( [ W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] )
Copy to clipboard Error Copied 下面的代码定义隐藏状态初始化函数init_gru_state
。同6.4节(循环神经网络的从零开始实现)中定义的init_rnn_state
函数一样,它返回由一个形状为(批量大小, 隐藏单元个数)的值为0的Tensor
组成的元组。
def init_gru_state ( batch_size, num_hiddens, device) :
return ( torch. zeros( ( batch_size, num_hiddens) , device= device) , )
Copy to clipboard Error Copied 下面根据门控循环单元的计算表达式定义模型。
def gru ( inputs, state, params) :
W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
H, = state
outputs = [ ]
for X in inputs:
Z = torch. sigmoid( torch. matmul( X, W_xz) + torch. matmul( H, W_hz) + b_z)
R = torch. sigmoid( torch. matmul( X, W_xr) + torch. matmul( H, W_hr) + b_r)
H_tilda = torch. tanh( torch. matmul( X, W_xh) + torch. matmul( R * H, W_hh) + b_h)
H = Z * H + ( 1 - Z) * H_tilda
Y = torch. matmul( H, W_hq) + b_q
outputs. append( Y)
return outputs, ( H, )
Copy to clipboard Error Copied 我们在训练模型时只使用相邻采样。设置好超参数后,我们将训练模型并根据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。
num_epochs, num_steps, batch_size, lr, clipping_theta = 160 , 35 , 32 , 1e2 , 1e - 2
pred_period, pred_len, prefixes = 40 , 50 , [ ‘分开’ , ‘不分开’ ] Copy to clipboard Error Copied
我们每过40个迭代周期便根据当前训练的模型创作一段歌词。
d2l. train_and_predict_rnn( gru, get_params, init_gru_state, num_hiddens,
vocab_size, device, corpus_indices, idx_to_char,
char_to_idx, False , num_epochs, num_steps, lr,
clipping_theta, batch_size, pred_period, pred_len,
prefixes)
Copy to clipboard Error Copied
输出:
epoch 40, perplexity 149.477598, time 1.08 sec
分开 我不不你 我想你你的爱我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你
不分开 我想你你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 epoch 80, perplexity 31.689210, time 1.10 sec
分开 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不
不分开 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不 epoch 120, perplexity 4.866115, time 1.08 sec
分开 我想要这样牵着你的手不放开 爱过 让我来的肩膀 一起好酒 你来了这节秋 后知后觉 我该好好生活 我
不分开 你已经不了我不要 我不要再想你 我不要再想你 我不要再想你 不知不觉 我跟了这节奏 后知后觉 又过 epoch 160, perplexity 1.442282, time 1.51 sec
分开 我一定好生忧 唱着歌 一直走 我想就这样牵着你的手不放开 爱可不可以简简单单没有伤害 你 靠着我的
不分开 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生活Copy to clipboard Error Copied 6.7.4 简洁实现 在PyTorch中我们直接调用nn
模块中的GRU
类即可。
lr = 1e - 2
gru_layer = nn. GRU( input_size= vocab_size, hidden_size= num_hiddens)
model = d2l. RNNModel( gru_layer, vocab_size) . to( device)
d2l. train_and_predict_rnn_pytorch( model, num_hiddens, vocab_size, device,
corpus_indices, idx_to_char, char_to_idx,
num_epochs, num_steps, lr, clipping_theta,
batch_size, pred_period, pred_len, prefixes)
Copy to clipboard Error Copied 输出:
epoch 40, perplexity 1.022157, time 1.02 sec
分开手牵手 一步两步三步四步望著天 看星星 一颗两颗三颗四颗 连成线背著背默默许下心愿 看远方的星是否听
不分开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不能 爱情走的太快就像龙卷风 不能承受我已无处 epoch 80, perplexity 1.014535, time 1.04 sec
分开始想像 爸和妈当年的模样 说著一口吴侬软语的姑娘缓缓走过外滩 消失的 旧时光 一九四三 在回忆 的路
不分开始爱像 不知不觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好 epoch 120, perplexity 1.147843, time 1.04 sec
分开都靠我 你拿着球不投 又不会掩护我 选你这种队友 瞎透了我 说你说 分数怎么停留 所有回忆对着我进攻
不分开球我有多烦恼多 牧草有没有危险 一场梦 我面对我 甩开球我满腔的怒火 我想揍你已经很久 别想躲 说你 epoch 160, perplexity 1.018370, time 1.05 sec
分开爱上你 那场悲剧 是你完美演出的一场戏 宁愿心碎哭泣 再狠狠忘记 你爱过我的证据 让晶莹的泪滴 闪烁
不分开始 担心今天的你过得好不好 整个画面是你 想你想的睡不著 嘴嘟嘟那可爱的模样 还有在你身上香香的味道Copy to clipboard Error Copied 小结
门控循环神经网络可以更好地捕捉时间序列中时间步距离较大的依赖关系。
门控循环单元引入了门的概念,从而修改了循环神经网络中隐藏状态的计算方式。它包括重置门、更新门、候选隐藏状态和隐藏状态。
重置门有助于捕捉时间序列里短期的依赖关系。
更新门有助于捕捉时间序列里长期的依赖关系。
参考文献 [1] Cho, K., Van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: Encoder-decoder approaches. arXiv preprint arXiv:1409.1259.
[2] Chung, J., Gulcehre, C., Cho, K., & Bengio, Y. (2014). Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.
注:除代码外本节与原书此节基本相同,原书传送门