RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed

I have run into this error several times. Over the past few days I went through backward and autograd a few times and finally got a feel for what is going on.

First, why does this error occur at all? In plain language: by the time we call backward a second time, the structure of the computation graph has already been torn down (its intermediate buffers have been freed). This is how PyTorch's dynamic-graph mechanism works, and it saves memory.

I will not go into the details of how backward works here. In training, the usual pattern is loss.backward(), where loss is a scalar.

In a PyTorch computation graph there are really only two kinds of elements: tensors and functions. Functions are the differentiable operations (addition, subtraction, multiplication, division, roots, powers, exponentials, logarithms, trigonometric functions, and so on), while tensors split into two classes: leaf nodes and non-leaf nodes. When backward() propagates gradients, it does not compute a gradient for every tensor; it only computes gradients for tensors that satisfy all of the following: 1. the tensor is a leaf node; 2. it has requires_grad=True; 3. the tensors that depend on it have requires_grad=True.
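A quick way to see these rules in action (a minimal sketch; the values are only illustrative):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)  # leaf node created by the user
y = x ** 2                                 # non-leaf: produced by an operation
z = y + 1                                  # non-leaf

print(x.is_leaf, y.is_leaf, z.is_leaf)     # True False False

z.backward()
print(x.grad)  # tensor(4.) -- dz/dx = 2x, evaluated at x = 2
print(y.grad)  # None -- gradients of non-leaf tensors are not retained
```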

Look at the following code.

[Figure 1: screenshot of the code that triggers the error]
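The code in the screenshot is not recoverable as text, but judging from the description that follows (x is a leaf node, y and z are not, and m.backward() runs before n.backward()), a minimal sketch that reproduces the error looks roughly like this; the exact expressions are an assumption:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)  # leaf node
y = x ** 2                                 # non-leaf
z = y * 3                                  # non-leaf
m = z + 1                                  # reaches x through z and y
n = y + 1                                  # reaches x through y

m.backward()  # first backward: frees the graph's buffers, including y's
n.backward()  # RuntimeError: n's graph runs through y, whose buffers are gone
```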

Running this raises the error shown at the top.

Drawing the computation graph for this code: x is a leaf node, while y and z are not.

[Figure 2: the computation graph of the code above]

After m.backward(), the computation graph inside the red box is destroyed; this is what the error means by "the buffers have already been freed". In this graph y is clearly not a leaf node, so it is not among the tensors whose gradient m.backward() computes and keeps; the freed "buffer" is exactly y's intermediate result. m.backward() computes and retains only x's gradient.

What if you do not want y's buffer to be freed? Call m.backward(retain_graph=True).
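Continuing the sketch above:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
z = y * 3
m = z + 1
n = y + 1

m.backward(retain_graph=True)  # keep the buffers alive after the first backward
n.backward()                   # the second pass through y now succeeds
print(x.grad)                  # tensor(16.) -- dm/dx + dn/dx = 6x + 2x at x = 2
```

Note that x.grad accumulates the gradients from both backward calls.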

(A note on the word buffer: in backpropagation, a buffer is something the optimizer does not need to update, as opposed to a parameter.)

[Figure 3: the computation graph traversed by n.backward(), in which y's buffer has already been freed]

When n.backward() runs, y participates in n's computation graph, but y's buffer has already been freed, hence the error.

Now look at the next example.

[Figure 4: screenshot of a code example that does not raise the error]
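Again the screenshot is not recoverable, but going by the explanation below (two graphs that share only the leaf x, with y and z as separate, unshared buffers), a plausible sketch is:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)  # the shared leaf node
y = x ** 2                                 # buffer belonging only to m's graph
m = y + 1
z = x ** 3                                 # buffer belonging only to n's graph
n = z + 1

m.backward()   # frees only the buffers of m's graph (y's)
n.backward()   # fine: n's graph through z is untouched
print(x.grad)  # tensor(16.) -- 2x + 3x**2 = 4 + 12 at x = 2
```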

This code does not raise an error.

[Figure 5: the two computation graphs, which intersect only at x]

Although the two computation graphs do intersect (they share x), x is a leaf node with requires_grad=True, so its gradient is retained after backward. The key point is that the buffers (y and z) are not shared, so the two backward calls do not interfere with each other.

 

Finally, a word about detach: detach turns a non-leaf node into a leaf node with requires_grad=False.

Look at the figure below; the blue arrows show the direction of flow during backward.

[Figure 6: computation graph with blue arrows marking the backward flow]

 

If we replace z with z.detach(), the flow becomes:

 

[Figure 7: the same graph after z.detach(); the backward flow stops at the detached node]

In other words, detach acts as a cutoff for the gradient flow.
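A minimal sketch of that cutoff (the variable names are assumed):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
z = y * 3
z_det = z.detach()  # new leaf tensor: is_leaf=True, requires_grad=False

print(z_det.is_leaf, z_det.requires_grad)  # True False

n = z_det + y  # gradient can flow back through y, but not through z_det
n.backward()
print(x.grad)  # tensor(4.) -- only the y branch contributes: 2x at x = 2
```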

 

To summarize, when you see RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed:

1. Check whether the computation feeding your loss.backward() reuses a buffer that an earlier backward call has already consumed (and freed).

2. If you want to feed that buffer into the graph as a leaf node (discarding its gradient history), use detach.

 

Reference: https://zhuanlan.zhihu.com/p/85506092
