Gradient checking in PyTorch: the gradcheck() function

gradcheck():

gradcheck(func, inputs, eps=1e-06, atol=1e-05, rtol=0.001, raise_exception=True, check_sparse_nnz=False)
Check gradients computed via small finite differences against analytical
gradients w.r.t. tensors in :attr:`inputs` that are of floating point type
and with ``requires_grad=True``.

The check between numerical and analytical gradients uses :func:`~torch.allclose`.

.. note::
    The default values are designed for :attr:`input` of double precision.
    This check will likely fail if :attr:`input` is of less precision, e.g.,
    ``FloatTensor``.

.. warning::
   If any checked tensor in :attr:`input` has overlapping memory, i.e.,
   different indices pointing to the same memory address (e.g., from
   :func:`torch.expand`), this check will likely fail because the numerical
   gradients computed by point perturbation at such indices will change
   values at all other indices that share the same memory address.
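
To see what "overlapping memory" means here, a quick illustration (my own sketch, not part of the docstring): torch.expand returns a view with a zero stride, so several indices map to the same storage element, and a numerical perturbation at one index shows up at all of them.

	import torch

	base = torch.randn(3, 1, dtype=torch.double, requires_grad=True)
	view = base.expand(3, 4)                    # no copy: the 4 entries in each row share one storage element
	print(view.stride())                        # (1, 0) -- the zero stride is the overlap
	print(view.data_ptr() == base.data_ptr())   # True: same underlying storage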

Args:
    func (function): a Python function that takes Tensor inputs and returns
        a Tensor or a tuple of Tensors
    inputs (tuple of Tensor or Tensor): inputs to the function
    eps (float, optional): perturbation for finite differences
    atol (float, optional): absolute tolerance
    rtol (float, optional): relative tolerance
    raise_exception (bool, optional): indicating whether to raise an exception if
        the check fails. The exception gives more information about the
        exact nature of the failure. This is helpful when debugging gradchecks.
    check_sparse_nnz (bool, optional): if True, gradcheck allows for SparseTensor input,
        and for any SparseTensor at input, gradcheck will perform check at nnz positions only.

Returns:
    True if all differences satisfy allclose condition
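
Since the default tolerances are tuned for double precision (see the note above), a minimal call that is expected to pass looks like the following sketch (the function and shape are just illustrative choices, not from the original post):

	import torch
	from torch.autograd import gradcheck

	x = torch.randn(4, dtype=torch.double, requires_grad=True)
	print(gradcheck(torch.sigmoid, (x,)))   # True: numerical and analytical gradients agree within tolerance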

Example

	import torch
	import torch.nn as nn
	from torch.autograd import gradcheck

	# Single-precision (float32) input: per the note above, the check is expected to fail
	inputs = torch.randn(1, 1, 2, 2, requires_grad=True)
	conv = nn.Conv2d(1, 1, 1, 1)
	test = gradcheck(lambda x: conv(x), (inputs,))
	print(test)

Output (the check fails because the input is single precision, so gradcheck reports the mismatching Jacobians):

	numerical: tensor([[0.0596, 0.0000, 0.0000, 0.0000],
	        [0.0000, 0.0447, 0.0000, 0.0000],
	        [0.0000, 0.0000, 0.0596, 0.0000],
	        [0.0000, 0.0000, 0.0000, 0.0447]])
	analytical: tensor([[0.0483, 0.0000, 0.0000, 0.0000],
	        [0.0000, 0.0483, 0.0000, 0.0000],
	        [0.0000, 0.0000, 0.0483, 0.0000],
	        [0.0000, 0.0000, 0.0000, 0.0483]])
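
The mismatch above is the precision issue the note warns about: torch.randn and nn.Conv2d default to float32, so the finite-difference estimate is too noisy for the default tolerances. A sketch of the same example in double precision, where the check should pass:

	import torch
	import torch.nn as nn
	from torch.autograd import gradcheck

	inputs = torch.randn(1, 1, 2, 2, dtype=torch.double, requires_grad=True)
	conv = nn.Conv2d(1, 1, 1, 1).double()           # cast weights and bias to float64 as well
	print(gradcheck(lambda x: conv(x), (inputs,)))  # expected: True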
