PyTorch Learning (8): Implementation of Recurrent Layers - GRUCell
GRU stands for Gated Recurrent Unit; it is a variant of the LSTM. Let's start with a small GRUCell example:
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.GRUCell(10, 20)
input = Variable(torch.randn(6, 3, 10))
hx = Variable(torch.randn(3, 20))
output = []
for i in range(6):
    hx = rnn(input[i], hx)
    output.append(hx)
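The shapes here are worth spelling out: the input holds 6 time steps of a batch of 3 samples with 10 features each, and the hidden state always has shape (batch, hidden_size). A quick sketch of the same loop with the shapes annotated (assuming PyTorch 0.4 or later, where plain tensors can be passed instead of Variable):

import torch
import torch.nn as nn

rnn = nn.GRUCell(10, 20)              # input_size=10, hidden_size=20
inp = torch.randn(6, 3, 10)           # 6 time steps, batch of 3, 10 features per step
hx = torch.zeros(3, 20)               # (batch, hidden_size)
output = []
for i in range(6):
    hx = rnn(inp[i], hx)              # one step: (3, 10) and (3, 20) -> (3, 20)
    output.append(hx)
print(len(output), output[0].shape)   # 6 torch.Size([3, 20])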
In the line rnn = nn.GRUCell(10, 20), the GRUCell is constructed. GRUCell's __init__ method first calls the __init__ method of the base class Module(object):
class Module(object):
    r"""Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing to nest them in
    a tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

    Submodules assigned in this way will be registered, and will have their
    parameters converted too when you call `.cuda()`, etc.
    """

    dump_patches = False

    r"""This allows better BC support for :meth:`load_state_dict`. In
    :meth:`state_dict`, the version number will be saved as in the attribute
    `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
    dictionary with keys follow the naming convention of state dict. See
    ``_load_from_state_dict`` on how to use this information in loading.

    If new parameters/buffers are added/removed from a module, this number shall
    be bumped, and the module's `_load_from_state_dict` method can compare the
    version number and do appropriate changes if the state dict is from before
    the change."""
    _version = 1

    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    ......
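The registration described in the docstring can be checked directly: assigning a submodule as an attribute puts it into self._modules, and its parameters then appear in state_dict(). A minimal sketch using the Model from the docstring:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()
print(list(model._modules.keys()))      # ['conv1', 'conv2']
print(list(model.state_dict().keys()))  # ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias']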
Execution then continues in the __init__ method of the subclass GRUCell(RNNCellBase), which initializes the cell's parameters:
def __init__(self, input_size, hidden_size, bias=True):
    super(GRUCell, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.bias = bias
    self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size))
    self.weight_hh = Parameter(torch.Tensor(3 * hidden_size, hidden_size))
    if bias:
        self.bias_ih = Parameter(torch.Tensor(3 * hidden_size))
        self.bias_hh = Parameter(torch.Tensor(3 * hidden_size))
    else:
        self.register_parameter('bias_ih', None)
        self.register_parameter('bias_hh', None)
    self.reset_parameters()
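Because the reset gate, update gate, and new gate are stacked along dimension 0, every weight and bias has a leading dimension of 3 * hidden_size. For the nn.GRUCell(10, 20) from the example this can be verified with a small sketch:

import torch.nn as nn

cell = nn.GRUCell(10, 20)
print(cell.weight_ih.shape)   # torch.Size([60, 10]), i.e. (3 * hidden_size, input_size)
print(cell.weight_hh.shape)   # torch.Size([60, 20]), i.e. (3 * hidden_size, hidden_size)
print(cell.bias_ih.shape)     # torch.Size([60])
print(cell.bias_hh.shape)     # torch.Size([60])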
When rnn(input[i], hx) is executed in the loop, Python invokes Module's __call__ method, which runs the registered hooks around the actual forward computation:

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        hook(self, input)
    if torch.jit._tracing:
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            raise RuntimeError(
                "forward hooks should never return any values, but '{}'"
                "didn't return None".format(hook))
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result
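The hook machinery in __call__ can be observed from user code: register_forward_pre_hook and register_forward_hook add entries to _forward_pre_hooks and _forward_hooks, which __call__ then runs before and after forward. A small sketch (the hook functions below are illustrative):

import torch
import torch.nn as nn

rnn = nn.GRUCell(10, 20)

def pre_hook(module, inp):
    # runs before forward; inp is the tuple of positional arguments
    print('forward_pre_hook, input shape:', inp[0].shape)

def fwd_hook(module, inp, out):
    # runs after forward; out is the value forward returned
    print('forward_hook, output shape:', out.shape)

rnn.register_forward_pre_hook(pre_hook)
rnn.register_forward_hook(fwd_hook)

hx = rnn(torch.randn(3, 10), torch.randn(3, 20))
# forward_pre_hook, input shape: torch.Size([3, 10])
# forward_hook, output shape: torch.Size([3, 20])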
In it, result = self.forward(*input, **kwargs) dispatches to GRUCell's forward method, which computes the result of the forward pass:
class GRUCell(RNNCellBase):
    ......
    def forward(self, input, hx=None):
        self.check_forward_input(input)
        if hx is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
        self.check_forward_hidden(input, hx)
        return self._backend.GRUCell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )
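Note the hx=None branch: if no hidden state is passed in, forward creates an all-zero hidden state of shape (batch, hidden_size) before delegating to the backend. A quick sketch of that behavior:

import torch
import torch.nn as nn

rnn = nn.GRUCell(10, 20)
x = torch.randn(3, 10)
h1 = rnn(x)                          # hx omitted, defaults to zeros
h2 = rnn(x, torch.zeros(3, 20))      # explicit zero hidden state
print(torch.allclose(h1, h2))        # True, both calls compute the same update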
GRUCell's forward method returns self._backend.GRUCell(input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh), which calls the backend GRUCell function:
def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    if input.is_cuda:
        gi = F.linear(input, w_ih)
        gh = F.linear(hidden, w_hh)
        state = fusedBackend.GRUFused.apply
        return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh)

    gi = F.linear(input, w_ih, b_ih)
    gh = F.linear(hidden, w_hh, b_hh)
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy
On the CPU path, the stacked projections gi and gh are split into three chunks each; resetgate, inputgate, and newgate then control how much of the old hidden state is kept and how much new information flows in, and the new hidden state hy is computed from them.
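Written as equations, the code above computes (with r for resetgate, z for inputgate, n for newgate, and \odot for elementwise multiplication):

r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr})
z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz})
n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn}))
h_t = n_t + z_t \odot (h_{t-1} - n_t) = (1 - z_t) \odot n_t + z_t \odot h_{t-1}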
The source code of the linear transformation F.linear is as follows:
def linear(input, weight, bias=None):
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    Shape:
        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        return torch.addmm(bias, input, weight.t())

    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    return output
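A quick check of the shapes F.linear produces in the GRUCell case above, where the weight has shape (3 * hidden_size, input_size); a small sketch:

import torch
import torch.nn.functional as F

x = torch.randn(3, 10)       # (N, in_features), one batch of GRUCell inputs
w = torch.randn(60, 10)      # (out_features, in_features), like weight_ih for hidden_size=20
b = torch.randn(60)          # (out_features,)
y = F.linear(x, w, b)        # y = x @ w.t() + b
print(y.shape)               # torch.Size([3, 60])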