pytorch 手写GRU

刚开始想直接从pytorch源码来整,结果你瞧瞧源码写的都是啥:

class GRU(RNNBase):
    """GRU module whose forward computation is delegated to ``_VF.gru``.

    This class holds no recurrence logic of its own: it prepares the input
    and the initial hidden state, calls ``_VF.gru``, and re-wraps the result.
    Attributes read here (``batch_first``, ``bidirectional``, ``num_layers``,
    ``hidden_size``, ``_flat_weights``, ``bias``, ``dropout``) and the helpers
    ``permute_hidden`` / ``check_forward_args`` are not defined in this class;
    presumably they come from ``RNNBase`` -- confirm against that class.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``RNNBase`` with the mode string ``'GRU'``."""
        super(GRU, self).__init__('GRU', *args, **kwargs)

    # Overload stub for a plain tensor input. The body is intentionally empty;
    # the signature is carried by the type comment inside the function, which
    # must stay exactly as written.
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input, hx=None):  # noqa: F811
        # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
        pass

    # Overload stub for a PackedSequence input; same conventions as the stub
    # above.
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input, hx=None):  # noqa: F811
        # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
        pass

    def forward(self, input, hx=None):  # noqa: F811
        """Run the GRU over ``input`` and return ``(output, hidden)``.

        Args:
            input: Either a plain tensor or a ``PackedSequence``. For a plain
                tensor the batch size is read from dimension 0 when
                ``self.batch_first`` is true, otherwise from dimension 1.
            hx: Optional initial hidden state. When ``None``, a zero tensor of
                shape ``(num_layers * num_directions, max_batch_size,
                hidden_size)`` is created with the dtype and device of the
                input data.

        Returns:
            A 2-tuple ``(output, hidden)``. ``output`` is wrapped back into a
            ``PackedSequence`` when the input was one; ``hidden`` is passed
            through ``self.permute_hidden`` with ``unsorted_indices`` in both
            cases.
        """
        # Keep a reference to the caller's object: ``input`` is rebound below
        # when a PackedSequence is unpacked.
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            # Unpack the four fields of the PackedSequence; from here on
            # ``input`` refers to its first field.
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            # Plain tensor: no per-step batch sizes and no index permutation.
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            # Default initial hidden state: all zeros, one slice per
            # (layer, direction) pair.
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        # NOTE(review): check_forward_args is defined outside this block;
        # presumably it validates input/hidden shapes -- confirm in RNNBase.
        self.check_forward_args(input, hx, batch_sizes)
        # _VF.gru is called with one of two argument lists: the plain-tensor
        # call passes ``batch_first``; the packed call passes ``batch_sizes``
        # instead and omits ``batch_first``.
        if batch_sizes is None:
            result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
                             self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
                             self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1]

        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            # Re-wrap the output with the same batch sizes and index tensors
            # that were unpacked from the input.
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            # unsorted_indices is None on this path (set in the else-branch above).
            return output, self.permute_hidden(hidden, unsorted_indices)

https://discuss.pytorch.org/t/where-to-find-torch-c-variablefunctions-module/41305/5

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp 嗯对，源码是用 C++ 写的，fine。那我就直接从头写一个吧，反正也不难。好了，正片开始了：

你可能感兴趣的:(NLP)