Hooks in PyTorch

Table of Contents

  • Hooks in PyTorch
    • Hooks on Tensors
    • Hooks on nn.Module
      • register_forward_hook(hook)
      • register_backward_hook(hook)

Hooks in PyTorch

An important mechanism in PyTorch is automatic differentiation (autograd).

If you need to record the values of intermediate variables, or to deliberately modify gradients during backpropagation, you need hooks.

There are three kinds of hooks:

(1) torch.Tensor.register_hook, for Tensors

(2) torch.nn.Module.register_forward_hook, for nn.Module

(3) torch.nn.Module.register_backward_hook, for nn.Module

Hooks on Tensors

This function is implemented in PyTorch as follows:

    def register_hook(self, hook):
        r"""Registers a backward hook.

        The hook will be called every time a gradient with respect to the
        Tensor is computed. The hook should have the following signature::

            hook(grad) -> Tensor or None


        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.

        Example::

            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
            >>> v.backward(torch.tensor([1., 2., 3.]))
            >>> v.grad

             2
             4
             6
            [torch.FloatTensor of size (3,)]

            >>> h.remove()  # removes the hook
        """
        if not self.requires_grad:
            raise RuntimeError("cannot register a hook on a tensor that "
                               "doesn't require gradient")
        if self._backward_hooks is None:
            self._backward_hooks = OrderedDict()
            if self.grad_fn is not None:
                self.grad_fn._register_hook_dict(self)
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

What does this mean? register_hook is used during the backward pass. Its argument is a function with the signature:

 hook(grad) -> Tensor or None

Here grad is the gradient of the tensor. The hook may return a new gradient, which will then be used in place of the original grad, but it should not modify grad in place. The hook can also be removed: register_hook returns a handle with a remove() method, so calling handle.remove() detaches the hook. The docstring above then gives an example: the hook doubles the gradient, so backpropagating [1., 2., 3.] into v leaves v.grad equal to [2., 4., 6.].

In the underlying implementation, the code up front is just validation (the tensor must require gradients); the key step is that a RemovableHandle is created and the hook is stored in the tensor's _backward_hooks dictionary under the handle's id, tying the hook to that tensor.
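
A common use is saving the gradient of an intermediate tensor, which autograd discards by default (only leaf tensors keep their .grad). A minimal runnable sketch, with variable names of my own choosing:

    import torch

    x = torch.tensor([1., 2., 3.], requires_grad=True)
    y = x * 2                 # intermediate tensor; autograd does not keep y.grad
    grads = {}

    def save_grad(grad):
        grads['y'] = grad     # record dL/dy at the moment it is computed
        return None           # returning None keeps the original gradient

    h = y.register_hook(save_grad)
    y.sum().backward()
    print(grads['y'])         # tensor([1., 1., 1.])
    h.remove()                # detach the hook once it is no longer needed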

Hooks on nn.Module

register_forward_hook(hook)

This function is implemented in PyTorch as follows:

    def register_forward_hook(self, hook):
        r"""Registers a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::

            hook(module, input, output) -> None or modified output

        The hook can modify the output. It can modify the input inplace but
        it will not have effect on forward since this is called after
        :func:`forward` is called.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

What does this mean? It registers a forward hook on the module: every time forward has computed an output, the hook is called. The hook has the signature:

hook(module, input, output) -> None or modified output

Note that, per the docstring, the hook can modify the output by returning a modified value. It may also modify the input in place, but that has no effect on the forward computation, since the hook runs after forward. Like register_hook, this returns a handle with a remove() method; calling handle.remove() removes the hook.

When the code evaluates model(X), the module's __call__ first invokes forward to run the forward pass, then checks whether any forward hooks are registered; if so, each hook is called with the module, its input, and its output.
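
A typical application is grabbing an intermediate activation for inspection or feature extraction. A minimal sketch (the toy model and the features dictionary are my own):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    features = {}

    def save_output(module, input, output):
        # record the activation; returning None leaves the output unchanged
        features['relu'] = output.detach()

    h = model[1].register_forward_hook(save_output)   # hook the ReLU layer
    _ = model(torch.randn(1, 4))
    print(features['relu'].shape)                     # torch.Size([1, 8])
    h.remove()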

register_backward_hook(hook)

First, look at the underlying code:

    def register_backward_hook(self, hook):
        r"""Registers a backward hook on the module.

        The hook will be called every time the gradients with respect to module
        inputs are computed. The hook should have the following signature::

            hook(module, grad_input, grad_output) -> Tensor or None

        The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
        module has multiple inputs or outputs. The hook should not modify its
        arguments, but it can optionally return a new gradient with respect to
        input that will be used in place of :attr:`grad_input` in subsequent
        computations.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        .. warning ::

            The current implementation will not have the presented behavior
            for complex :class:`Module` that perform many operations.
            In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
            contain the gradients for a subset of the inputs and outputs.
            For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
            directly on a specific input or output to get the required gradients.

        """
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

As with the forward hook, this registers a backward hook on the module. The hook is called every time the gradients with respect to the module's inputs are computed during backward, and it has the signature:

hook(module, grad_input, grad_output) -> Tensor or None
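
As the docstring explains, grad_input and grad_output may be tuples when the module has multiple inputs or outputs. The hook should not modify its arguments, but it may return a new gradient with respect to the input, which replaces grad_input in subsequent computations. Note the warning: for complex modules that perform many operations, grad_input and grad_output may only cover a subset of the inputs and outputs, in which case hooking specific tensors with torch.Tensor.register_hook is more reliable. Below is a minimal sketch of the usual pattern (the tiny layer and the print are illustrative only; newer PyTorch versions recommend register_full_backward_hook over this API):

    import torch
    import torch.nn as nn

    layer = nn.Linear(3, 1)

    def inspect_grads(module, grad_input, grad_output):
        # grad_output holds dL/d(output); grad_input holds gradients w.r.t.
        # the module's inputs. Returning None keeps them unchanged.
        print('grad_output:', grad_output)

    h = layer.register_backward_hook(inspect_grads)
    out = layer(torch.randn(2, 3))
    out.sum().backward()   # triggers the hook
    h.remove()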

