PyTorch - Printing Gradients

import torch

# Variable has been deprecated since PyTorch 0.4; tensors track gradients directly
x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
z = x + y
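# the hook receives the gradient with respect to z during backward();
# the lambda form works only in Python 3, where print is a function
z.register_hook(lambda g: print(g))
# if you're using Python 2, do this instead:
# def print_grad(g):
#     print g
# z.register_hook(print_grad)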
q = z.sum()
q.backward()
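
The hook above only prints the gradient. If you want to keep it for later inspection, one possible variation is sketched below. It assumes a recent PyTorch release where tensors carry `requires_grad` directly; the `captured` list and the `save_grad` helper are hypothetical names introduced just for this illustration. It also shows `retain_grad()`, which asks autograd to keep `.grad` on a non-leaf tensor such as `z`.

```python
import torch

x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
z = x + y

captured = []  # hypothetical container for gradients seen by the hook


def save_grad(g):
    # Store a copy of the gradient that flows into z during backward().
    captured.append(g.clone())


handle = z.register_hook(save_grad)  # returns a handle that can remove the hook
z.retain_grad()                      # keep z.grad even though z is not a leaf

q = z.sum()
q.backward()
handle.remove()                      # the hook will not fire on later backward passes

print(captured[0])  # gradient of q with respect to z
print(z.grad)       # the same values, kept by retain_grad()
print(x.grad)       # leaf tensors get .grad populated automatically
```

Since `q` is the sum of all entries of `z`, and `z = x + y`, each of these gradients is a 5x5 tensor of ones.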
