Pytorch中backward(retain_graph=True)的 retain_graph参数解释

每次 backward() 时,默认会把整个计算图free掉。一般情况下是每次迭代,只需一次 forward() 和一次 backward() ,前向运算forward() 和反向传播backward()是成对存在的,一般一次backward()也是够用的。

 

但是不排除,由于自定义loss等的复杂性,需要一次forward(),多个不同loss的backward()来累积同一个网络的grad,来更新参数。于是,若在当前backward()后,不执行forward() 而是执行另一个backward(),需要在当前backward()时,指定保留计算图,backward(retain_graph)。

你可能感兴趣的:(深度学习,点滴)