Gradient issues in PyTorch

1. If an input requires a gradient, the output requires one too

We can only set the requires_grad attribute on leaf nodes of the computation graph to decide whether a tensor records gradients; we cannot set requires_grad on the nodes produced by operations on them. Whether a derived node requires a gradient is determined by its inputs: as long as at least one input node has requires_grad=True, the result's requires_grad is also True.

import torch

x = torch.randn(2, 100)
x.requires_grad = False   # leaf node: explicitly opts out of gradients
w = torch.randn(10, 100)
w2 = torch.randn(3, 10)
w.requires_grad = True
w2.requires_grad = True
y = x @ w.t()   # non-leaf: requires_grad inherited from w
z = y @ w2.t()  # non-leaf: requires_grad inherited from y and w2
print(y.requires_grad, z.requires_grad)  # True True

z.sum().backward()
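
The leaf-only restriction is enforced at runtime: assigning requires_grad on a non-leaf tensor raises a RuntimeError. A minimal sketch of mine (not from the original example) demonstrating this:

import torch

w = torch.randn(10, 100, requires_grad=True)
y = torch.randn(2, 100) @ w.t()  # y is a non-leaf node

try:
    y.requires_grad = False  # only leaf tensors accept this assignment
except RuntimeError as e:
    print(e)  # the error message suggests using y.detach() instead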

2. Getting the gradient of an intermediate node

For a leaf node with requires_grad set, we can read v.grad after backward. For an intermediate variable, however, v.grad is None by default (autograd frees intermediate gradients to save memory; calling v.retain_grad() before backward is one way to keep them). The other way is register_hook, which registers a function that is called with the gradient when backpropagation reaches that variable. Below is a simple version that just prints the gradient.

import torch

def hook(grad):
    print(grad)

x = torch.randn(2, 100)
x.requires_grad = False
w = torch.randn(10, 100)
w2 = torch.randn(3, 10)
w.requires_grad = True
w2.requires_grad = True
y = x @ w.t()
z = y @ w2.t()

y.register_hook(hook)
z.sum().backward()  # hook is invoked here with dL/dy
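
As mentioned above, retain_grad() is a lighter alternative when you only need to read the gradient after one backward pass; a minimal sketch of mine, not part of the original post:

import torch

x = torch.randn(2, 100)
w = torch.randn(10, 100, requires_grad=True)
y = x @ w.t()
y.retain_grad()  # ask autograd to keep y.grad after backward

y.sum().backward()
print(y.grad)  # no longer None; all ones here, since dL/dy = 1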

An improved version that collects the gradients into a dict instead of printing them:

import torch


class GradCollector(object):
    def __init__(self):
        self.grads = {}

    def __call__(self, name: str):
        # Return a hook that stores the incoming gradient under `name`.
        def hook(grad):
            self.grads[name] = grad
        return hook


x = torch.randn(2, 100)
x.requires_grad = False
w = torch.randn(10, 100)
w2 = torch.randn(3, 10)
w.requires_grad = True
w2.requires_grad = True
y = x @ w.t()
z = y @ w2.t()

grad_collector = GradCollector()
y.register_hook(grad_collector("y"))
z.register_hook(grad_collector("z"))

z.sum().backward()

print(grad_collector.grads['y'])
print(grad_collector.grads['z'])
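
Two register_hook details worth knowing (standard PyTorch behavior; the snippet is my sketch, not code from the original post): register_hook returns a handle whose remove() method unregisters the hook, and if the hook returns a tensor, that tensor replaces the gradient flowing backward, which enables tricks like clipping:

import torch

x = torch.randn(2, 100)
w = torch.randn(10, 100, requires_grad=True)
y = x @ w.t()

# The returned tensor replaces dL/dy in the backward pass.
handle = y.register_hook(lambda grad: grad.clamp(-1.0, 1.0))

y.sum().backward()
handle.remove()  # unregister the hook once it is no longer needed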
