# Basic autograd demo: build y from x, run backward(), and inspect x.grad.
import torch  # fix: the original file never imported torch

x = torch.arange(4.0)
print("x = ", x)  # x = tensor([0., 1., 2., 3.])

# Enable gradient tracking in place; equivalent to
# x = torch.arange(4.0, requires_grad=True)
x.requires_grad_(True)
print("x = ", x)  # x = tensor([0., 1., 2., 3.], requires_grad=True)
print("x.grad = ", x.grad)  # x.grad = None — default before any backward pass

y = 2 * torch.dot(x, x)
print("y = ", y)  # y = tensor(28., grad_fn=<MulBackward0>)
y.backward()
# d(2*x.x)/dx = 4x
print("x.grad = ", x.grad)  # x.grad = tensor([ 0., 4., 8., 12.])

# By default PyTorch accumulates gradients, so clear the previous values
# before the next backward pass.
x.grad.zero_()
y = x.sum()
y.backward()
print("x.grad = ", x.grad)  # x.grad = tensor([1., 1., 1., 1.])
# Demonstrate detach(): u = y.detach() is treated as a constant w.r.t. x,
# so z = u * x has gradient u (the value of x*x at the time of detaching).
# NOTE: this section reuses the leaf tensor `x` created above.
x.grad.zero_()
y = x * x
u = y.detach()      # u shares data with y but is cut out of the graph
z = u * x
z.sum().backward()  # dz/dx = u, because u is a constant here
print("x.grad = ", x.grad)  # x.grad = tensor([0., 1., 4., 9.])
print("u = ", u)            # u = tensor([0., 1., 4., 9.])

x.grad.zero_()
y = x * x
z = y * x           # z = x**3, so dz/dx = 3*x**2 (original comment wrongly said 3*x**3)
z.sum().backward()
print("x.grad = ", x.grad)  # x.grad = tensor([ 0., 3., 12., 27.])
备注:torch.Tensor.detach() 方法说明:
返回一个新的Variable,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个Variable永远不需要计算其梯度,不具有grad。
反向传播到调用 detach() 得到的这个 Variable 处就会停止,梯度不会再经由它继续向前传回原计算图。即使之后重新将它的 requires_grad 置为 True,它也只是作为一个新的叶子节点参与后续计算并累积自己的 grad,而不会把梯度传回 detach 之前的那部分计算图。
控制流的梯度计算:即使构建函数的计算图需要通过 Python控制流(例如,条件、循环或任意函数调⽤),我们仍然可以计算得到的变量的梯度。
def f(a):
    """Control-flow autograd demo.

    Doubles ``b = 2*a`` until ``b.norm() >= 1000``, then returns either ``b``
    or ``100*b`` depending on the sign of the sum.  Either way the result is
    ``k * a`` for some scalar ``k`` fixed by the control flow, so after
    ``d = f(a); d.backward()`` we have ``a.grad == d / a``.
    """
    b = a * 2
    # Python control flow (loops/conditionals) is fine: autograd records the
    # operations actually executed.
    while b.norm() < 1000:
        print("b = ", b)
        b = b * 2
    if b.sum() > 0:
        c = b
        print("c ==", c)
    else:
        c = 100 * b
        print("c = ", c)
    return c


if __name__ == '__main__':
    import torch  # fix: the original file never imported torch

    a = torch.randn(size=(), requires_grad=True)
    print("a = ", a)
    d = f(a)
    d.backward()
    print("a.grad = ", a.grad)
    # f(a) = k * a for a constant k, hence a.grad = k = d / a
    # (the original comment wrongly stated the derivative as f(a)/d).
    print("d/a = ", d / a)
输出结果如下:
D:\workspace\pyPrj\venv\Scripts\python.exe D:/workspace/pyPrj/main.py
a = tensor(-0.6704, requires_grad=True)
b = tensor(-1.3407, grad_fn=<MulBackward0>)
b = tensor(-2.6814, grad_fn=<MulBackward0>)
b = tensor(-5.3628, grad_fn=<MulBackward0>)
b = tensor(-10.7257, grad_fn=<MulBackward0>)
b = tensor(-21.4513, grad_fn=<MulBackward0>)
b = tensor(-42.9027, grad_fn=<MulBackward0>)
b = tensor(-85.8054, grad_fn=<MulBackward0>)
b = tensor(-171.6107, grad_fn=<MulBackward0>)
b = tensor(-343.2215, grad_fn=<MulBackward0>)
b = tensor(-686.4429, grad_fn=<MulBackward0>)
c = tensor(-137288.5938, grad_fn=<MulBackward0>)
a.grad = tensor(204800.)
d/a = tensor(204800.0156, grad_fn=<DivBackward0>)
Process finished with exit code 0