Pytorch :Trying to backward through the graph a second time, but the buffers have already been freed

最近在学习Pytorch,刚用Pytorch重写了之前用Tensorlfow写的论文代码。
首次运行就碰到了一个bug:
Pytorch - RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
刚开始按照这个错误提示,设置loss.backward(retain_graph=True),虽然解决了这个问题,但是随着训练的继续,报错OOM。很尴尬。。。
查了stackoverflow上的方法,最终解决了问题。

我原来的代码是:

     for side in outputs:
         loss += Loss(side, label)

     loss.backward(retain_graph=True)

很显然,一旦调用loss.backward(), 就相当于调用了多次的Loss(side, label).backward()方法,而Pytorch的机制是每次调用.backward()都会free掉所有buffers,所以它提示,让retain_graph。然而当retain后,buffers就不会被free了,所以会OOM。
最后的解决办法就是, 分开写:

    side0 = Loss(output[0], label)
    side1 = Loss(output[1], label)
    side2 = Loss(output[2], label)
    side3 = Loss(output[3], label)
    side4 = Loss(output[4], label)
    side5 = Loss(output[5], label)
    loss = side0 + side1 + side2 + side3 + side4 + side5

你可能感兴趣的:(Pytorch学习笔记)