torch.autograd.grad()函数使用
import torch
x = torch.tensor([[0.,1.,2.,3.],[1.,2.,3.,4.],[2.,3.,4.,5.]]).requires_grad_(True)
'''
tensor([[0., 1., 2., 3.],
[1., 2., 3., 4.],
[2., 3., 4., 5.]], requires_grad=True)
'''
y = x ** 2 + x
'''
tensor([[ 0., 2., 6., 12.],
[ 2., 6., 12., 20.],
[ 6., 12., 20., 30.]], grad_fn=)
'''
weight = torch.ones(y.size()).requires_grad_(True)
'''
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], requires_grad=True)
'''
dydx1 = torch.autograd.grad(outputs=y,
inputs=x,
grad_outputs=weight,
retain_graph=True,
create_graph=True,
only_inputs=True)
'''
(tensor([[ 1., 3., 5., 7.],
[ 3., 5., 7., 9.],
[ 5., 7., 9., 11.]], grad_fn=),)
'''
dydx2 = torch.autograd.grad(outputs=dydx1,
inputs=x,
grad_outputs=weight,
retain_graph=True,
create_graph=True,
only_inputs=True)
'''
(tensor([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]], grad_fn=),)
'''
z = torch.sum(x)
dzdx = torch.autograd.grad(outputs=z,
inputs=x,
retain_graph=True,
create_graph=True,
only_inputs=True)
print(dzdx)
'''
(tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]),)
'''