pytorch报错详解:RuntimeError: Trying to backward through the graph a second time

代码如下,当我尝试对y3进行第二次梯度计算时,报了这个错

import torch
x = torch.tensor([1], dtype=torch.float32, requires_grad=True)
y1 = x ** 2
y2 = x ** 3
y3 = y1 + y2
y3.backward() 
print(x.grad)
x.grad.data.zero_()
y3.backward()
print(x.grad)

百度了一下问题原因和解决方法,
解决方法:在第一次backward中加一句retain_grad=True,意思为一直保留计算图,问题解决。

y3.backward(retain_graph=True)

报错原因就是pytorch的计算图在第一次执行完backward计算梯度的时候就已经被释放了。第二次想要再用计算图计算时,计算图已经没了,自然报错

那么,计算图是什么?
针对上述例子,计算图可以画成这样:
pytorch报错详解:RuntimeError: Trying to backward through the graph a second time_第1张图片
当第二次backward时,无法从y3回溯到y1/y2,这是报错的根源。

此外,我还尝试了另一个例子:

x = torch.randn(3,3,requires_grad=True)
print(x)
y = x + x
out2 = y.sum()
out2.backward()
print(x.grad)
x.grad.data.zero_()
out2.backward()
print(x.grad)

神奇的是,就算连续调用两次backward,依然不报错。当时我的猜想是这根据计算图的计算方法复杂程度而定,第一个例子变量都来自x,一个是二次方,一个是三次方,计算方法不同,计算图必须保留;第二个例子只用了加法一种计算方法,但随后我将这种想法推翻了,因为这个例子:

y = x*x
z = y*y
a = z*z
a = a.sum()
a.backward()
a.backward()

计算方法单一,但结果报错。我又将乘法都换成了加法,就没错了O_o

总结一下,只要是 只有加法存在 这种情况,不管套了多少层,backward()即可,剩下的情况都要backward(retain_graph=True)

你可能感兴趣的:(机器学习,pytorch,人工智能,经验分享,程序人生,debug)