An important mechanism in PyTorch is automatic differentiation (autograd).
If you want to record the values of intermediate variables during the computation, or to change the gradients by hand, you need hooks.
There are three kinds of hooks:
(1) torch.Tensor.register_hook, for a tensor
(2) torch.nn.Module.register_forward_hook, for an nn.Module
(3) torch.nn.Module.register_backward_hook, for an nn.Module
register_hook is implemented in PyTorch as follows:
def register_hook(self, hook):
    r"""Registers a backward hook.

    The hook will be called every time a gradient with respect to the
    Tensor is computed. The hook should have the following signature::

        hook(grad) -> Tensor or None

    The hook should not modify its argument, but it can optionally return
    a new gradient which will be used in place of :attr:`grad`.

    This function returns a handle with a method ``handle.remove()``
    that removes the hook from the module.

    Example::

        >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
        >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
        >>> v.backward(torch.tensor([1., 2., 3.]))
        >>> v.grad

         2
         4
         6
        [torch.FloatTensor of size (3,)]

        >>> h.remove()  # removes the hook
    """
    if not self.requires_grad:
        raise RuntimeError("cannot register a hook on a tensor that "
                           "doesn't require gradient")
    if self._backward_hooks is None:
        self._backward_hooks = OrderedDict()
        if self.grad_fn is not None:
            self.grad_fn._register_hook_dict(self)
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
So what does this mean? register_hook is used during the backward pass. It takes a single argument, a function with the signature:
hook(grad) -> Tensor or None
grad is the gradient of this tensor. The hook may return a new gradient, which will be used in place of the original one, but it should not modify grad itself. The hook can also be removed: register_hook returns a handle with a remove() method, so calling handle.remove() detaches the hook. The docstring above then gives an example of how register_hook is used.
In the underlying implementation, the first few lines are just sanity checks; the important part is that the hook is stored in self._backward_hooks under an id taken from the handle, which ties the hook to this tensor.
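To make this concrete, here is a minimal sketch of the typical use case from the introduction: capturing (and optionally rescaling) the gradient of an intermediate, non-leaf tensor. The names grad_hook and saved are made up for this illustration.

import torch

saved = {}  # illustrative container for captured gradients

def grad_hook(grad):
    saved['y_grad'] = grad.clone()   # record the incoming gradient d(loss)/dy
    return grad * 0.5                # optionally replace it with a scaled copy

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x * 2                            # non-leaf tensor: y.grad is not retained
h = y.register_hook(grad_hook)

loss = (y ** 2).sum()
loss.backward()

print(saved['y_grad'])               # d(loss)/dy = 2*y = tensor([4., 8., 12.])
print(x.grad)                        # uses the halved gradient: tensor([4., 8., 12.])
h.remove()                           # stop the hook from firing on later backward passes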
register_forward_hook is implemented in PyTorch as follows:
def register_forward_hook(self, hook):
    r"""Registers a forward hook on the module.

    The hook will be called every time after :func:`forward` has computed an output.
    It should have the following signature::

        hook(module, input, output) -> None or modified output

    The hook can modify the output. It can modify the input inplace but
    it will not have effect on forward since this is called after
    :func:`forward` is called.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(self._forward_hooks)
    self._forward_hooks[handle.id] = hook
    return handle
What does this mean? It registers a forward hook on a module; every time forward has computed an output, the hook is called. The hook has the signature:
hook(module, input, output) -> None or modified output
The hook may modify the output (or return a replacement for it); modifying the input in place has no effect on the current forward pass, because the hook runs only after forward has finished. register_forward_hook also returns a handle with a remove() method, so the hook can be removed with handle.remove().
When the code evaluates model(X), the underlying __call__ first runs forward to compute the output, then checks whether any forward hooks have been registered; if so, each registered hook is called to do its work.
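A minimal sketch of this in use, grabbing the output of an intermediate layer (the small Sequential model and the features dict are made up for illustration):

import torch
import torch.nn as nn

features = {}  # illustrative container for captured activations

def forward_hook(module, input, output):
    features['relu_out'] = output.detach()   # save this layer's output

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
h = model[1].register_forward_hook(forward_hook)  # hook the ReLU layer

x = torch.randn(3, 4)
y = model(x)                        # forward runs first, then the hook fires

print(features['relu_out'].shape)   # torch.Size([3, 8])
h.remove()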
Finally, register_backward_hook. First, look at the underlying code:
def register_backward_hook(self, hook):
    r"""Registers a backward hook on the module.

    The hook will be called every time the gradients with respect to module
    inputs are computed. The hook should have the following signature::

        hook(module, grad_input, grad_output) -> Tensor or None

    The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
    module has multiple inputs or outputs. The hook should not modify its
    arguments, but it can optionally return a new gradient with respect to
    input that will be used in place of :attr:`grad_input` in subsequent
    computations.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    .. warning ::

        The current implementation will not have the presented behavior
        for complex :class:`Module` that perform many operations.
        In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
        contain the gradients for a subset of the inputs and outputs.
        For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
        directly on a specific input or output to get the required gradients.
    """
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
As with the forward hook, this registers a hook on a module, but it is called during the backward pass, every time the gradients with respect to the module's inputs are computed. The hook has the signature:
hook(module, grad_input, grad_output) -> Tensor or None
grad_input and grad_output may be tuples when the module has several inputs or outputs. The hook should not modify its arguments, but it may return a new gradient with respect to the input, which replaces grad_input in the rest of the backward computation. Note the warning in the docstring: for complex modules that perform many operations, grad_input and grad_output may only cover a subset of the inputs and outputs, and torch.Tensor.register_hook on a specific tensor is then the safer choice.
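A minimal sketch on a single Linear layer, which simply prints the shapes of grad_input and grad_output (for such a simple module the documented behavior holds; the exact ordering of the grad_input tuple can differ between PyTorch versions, so this only inspects shapes):

import torch
import torch.nn as nn

def backward_hook(module, grad_input, grad_output):
    # grad_output: gradient w.r.t. the module's output (a 1-tuple here)
    # grad_input: gradients w.r.t. the module's inputs (ordering may vary)
    print('grad_output:', [g.shape for g in grad_output if g is not None])
    print('grad_input: ', [g.shape for g in grad_input if g is not None])

layer = nn.Linear(4, 2)
h = layer.register_backward_hook(backward_hook)

x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()           # triggers the backward hook
h.remove()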