刚开始想直接从 PyTorch 源码入手，结果你瞧瞧源码里写的都是啥：
class GRU(RNNBase):
    """GRU layer as defined in PyTorch's ``torch.nn`` (quoted source).

    The constructor forwards every argument unchanged to ``RNNBase`` together
    with the mode string ``'GRU'``. ``forward`` reads ``num_layers``,
    ``hidden_size``, ``bidirectional``, ``batch_first``, ``bias``, ``dropout``,
    ``training`` and ``_flat_weights`` from ``self``; presumably ``RNNBase``
    sets these up (NOTE(review): confirm against RNNBase). The recurrence
    itself is not computed in Python: it is delegated to the native
    ``_VF.gru`` op.
    """

    def __init__(self, *args, **kwargs):
        # All configuration is handled by RNNBase; 'GRU' presumably selects
        # the GRU variant of the shared RNN machinery.
        super(GRU, self).__init__('GRU', *args, **kwargs)

    # TorchScript overload declaration #1: padded Tensor input.
    # Keep the stub body a bare `pass` and keep the signature comment on the
    # line right after `def`. TorchScript derives the overload signature from
    # that comment. NOTE(review): do not add a docstring to these stubs;
    # confirm the stub-body rules in torch._jit_internal.
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input, hx=None):  # noqa: F811
        # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
        pass

    # TorchScript overload declaration #2: PackedSequence input.
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input, hx=None):  # noqa: F811
        # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
        pass

    def forward(self, input, hx=None):  # noqa: F811
        """Run the GRU over ``input``, starting from hidden state ``hx``.

        Args:
            input: a ``PackedSequence`` or a padded Tensor. For a Tensor, the
                batch size is read from dim 0 when ``self.batch_first`` is
                set and from dim 1 otherwise.
            hx: optional initial hidden state. When None, a zero tensor of
                shape (num_layers * num_directions, max_batch_size,
                hidden_size) is created with the input's dtype and device.

        Returns:
            ``(output, hidden)``: ``output`` is re-wrapped as a
            ``PackedSequence`` when ``input`` was one, otherwise it is the
            first element returned by ``_VF.gru``. ``hidden`` is the second
            element, passed back through ``permute_hidden`` with
            ``unsorted_indices``.
        """
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            # Unpack the flat data plus the metadata needed to rebuild the
            # PackedSequence later; batch_sizes[0] is taken as the largest
            # batch size.
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            # Padded Tensor: there are no batch_sizes and no reordering indices.
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            # Default initial state: zeros, one slice per layer per direction.
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch entry of the hidden state should match the input
            # sequence the caller believes they are passing in, so reorder hx
            # with sorted_indices. For Tensor input sorted_indices is None;
            # presumably permute_hidden is then a no-op. Confirm in RNNBase.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            # Padded form of the native op: takes batch_first, no batch_sizes.
            result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
                             self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            # Packed form of the native op: takes batch_sizes, no batch_first.
            result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
                             self.num_layers, self.dropout, self.training, self.bidirectional)
        # _VF.gru yields (output sequence, final hidden state).
        output = result[0]
        hidden = result[1]

        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            # Rebuild the PackedSequence around the output data and undo the
            # earlier hidden-state reordering with unsorted_indices.
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            return output, self.permute_hidden(hidden, unsorted_indices)
https://discuss.pytorch.org/t/where-to-find-torch-c-variablefunctions-module/41305/5
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp
嗯对，真正干活的 `_VF.gru` 是用 C++ 写的，fine。那我就直接从头写一个吧，反正也不难。好了，正片开始了：