gradcheck(func, inputs, eps=1e-06, atol=1e-05, rtol=0.001, raise_exception=True, check_sparse_nnz=False)
Check gradients computed via small finite differences against analytical gradients w.r.t. tensors in :attr:`inputs` that are of floating point type and with ``requires_grad=True``.
The check between numerical and analytical gradients uses :func:`~torch.allclose`.
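To make the comparison concrete, below is a minimal sketch of the numerical side: a central-difference approximation of the gradient of a scalar-valued function, compared against the autograd result with torch.allclose. The helper numerical_grad is illustrative only, not part of the gradcheck API.

import torch

def numerical_grad(func, x, eps=1e-6):
    # Central-difference approximation of the gradient of a scalar-valued
    # func at x, perturbing one element at a time (a simplified sketch of
    # what gradcheck does on the numerical side).
    x = x.detach().clone()
    grad = torch.zeros_like(x)
    flat_x, flat_g = x.view(-1), grad.view(-1)  # views share storage with x/grad
    for i in range(flat_x.numel()):
        orig = flat_x[i].item()
        flat_x[i] = orig + eps
        f_plus = func(x).item()
        flat_x[i] = orig - eps
        f_minus = func(x).item()
        flat_x[i] = orig
        flat_g[i] = (f_plus - f_minus) / (2 * eps)
    return grad

# Compare with the analytical gradient from autograd, as gradcheck does:
x = torch.randn(4, dtype=torch.double, requires_grad=True)
func = lambda t: (t ** 3).sum()
func(x).backward()
print(torch.allclose(numerical_grad(func, x), x.grad, rtol=1e-3, atol=1e-5))  # True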
.. note::
The default values are designed for :attr:`input` of double precision.
This check will likely fail if :attr:`input` is of less precision, e.g.,
``FloatTensor``.
.. warning::
If any checked tensor in :attr:`input` has overlapping memory, i.e.,
different indices pointing to the same memory address (e.g., from
:func:`torch.expand`), this check will likely fail because the numerical
gradients computed by point perturbation at such indices will change
values at all other indices that share the same memory address.
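For illustration (a hedged sketch, not from the original docs): the expanded tensor below has three indices backed by a single storage element, so this check is expected to fail, either with a Jacobian mismatch or, on newer versions, an error about writing to overlapping memory.

import torch
from torch.autograd import gradcheck

# expand() yields three indices that share one storage element, so the
# point perturbation of any single index effectively perturbs all of
# them and corrupts the numerical Jacobian.
x = torch.randn(1, dtype=torch.double, requires_grad=True).expand(3)
gradcheck(lambda t: (t * 2).sum(), (x,))  # expected to fail / raise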
Args:
func (function): a Python function that takes Tensor inputs and returns
a Tensor or a tuple of Tensors
inputs (tuple of Tensor or Tensor): inputs to the function
eps (float, optional): perturbation for finite differences
atol (float, optional): absolute tolerance
rtol (float, optional): relative tolerance
raise_exception (bool, optional): whether to raise an exception if
the check fails. The exception gives more information about the
exact nature of the failure, which is helpful when debugging gradchecks.
check_sparse_nnz (bool, optional): if True, gradcheck allows SparseTensor inputs,
and for any SparseTensor input, gradcheck will perform the check at nnz positions only
(see the sketch just after this argument list).
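A hedged sketch of the sparse path, assuming a PyTorch version that has check_sparse_nnz; using to_dense as the function under test is just an illustrative choice:

import torch
from torch.autograd import gradcheck

# A sparse input with requires_grad; gradcheck perturbs and checks only
# the nnz (explicitly stored) values, not the implicit zeros.
x = torch.randn(2, 2, dtype=torch.double).to_sparse().requires_grad_(True)
print(gradcheck(lambda t: t.to_dense(), (x,), check_sparse_nnz=True))  # True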
Returns:
True if all differences satisfy the allclose condition
import torch
import torch.nn as nn
from torch.autograd import gradcheck

# Single-precision inputs: per the note above, this check is expected to
# fail, and the exception reports the mismatched Jacobians shown below.
inputs = torch.randn(1, 1, 2, 2, requires_grad=True)
conv = nn.Conv2d(1, 1, 1, 1)
test = gradcheck(lambda x: conv(x), (inputs,))
print(test)
Output (the numerical and analytical Jacobians disagree because of single-precision rounding, so the check fails):
numerical: tensor([[0.0596, 0.0000, 0.0000, 0.0000],
                   [0.0000, 0.0447, 0.0000, 0.0000],
                   [0.0000, 0.0000, 0.0596, 0.0000],
                   [0.0000, 0.0000, 0.0000, 0.0447]])
analytical: tensor([[0.0483, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0483, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0483, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0483]])
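Casting the same example to double precision, as the note above recommends, is expected to make the check pass (a sketch under that assumption):

import torch
import torch.nn as nn
from torch.autograd import gradcheck

inputs = torch.randn(1, 1, 2, 2, dtype=torch.double, requires_grad=True)
conv = nn.Conv2d(1, 1, 1, 1).double()  # cast weights and bias to double as well
print(gradcheck(lambda x: conv(x), (inputs,)))  # True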