PyTorch保存中间变量的导数值

在利用autograd机制时,一般只会保存函数值对输入的导数值,而中间变量的导数值都没有保留。例如:

x=torch.tensor(3., requires_grad=True)

x1=x**2

y=4*x1

y.backward()

查看导数值:

x.grad    # 输出24
x1.grad    # 没有输出

这时候可以利用register_hook方法进行操作,它需要一个函数作为参数,例如:

# 此处是第一处的计算代码,是为了生成计算图,这里省略了
def extract(g):
    global x1g    # 表示x1处的导数值
    x1g = g

x1.register_hook(extract)
y.backward()
x1g

这时候就会输出x1处的导数值4。

register_hook方法还可以用来打印中间变量导数值(print),或者修改中间导数值(如x1.register_hook(lambda g: g+1))。这都体现在他的参数(自定义的函数)里。

你可能感兴趣的:(python)