pytorch报错:RuntimeError: Trying to backward through the graph a second time

在进行Gan的训练过程中,经常会遇到这个问题:RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed

这个错误的描述就是,你在进行方向传播的时候,缓存提前释放了,很多解决办法都是说在loss.backward() 改成loss.backward(retain_graph=True),实际上,大部分的代码问题不出在这里,而且保留计算图可能会使缓存快速积累,导致显存爆了,其实真正原因是由于判别器和生产器公用了相同的变量,只要在第一次用变量时候加上detach()就可以了。

#### Update Discriminator ###
real_preds = netD(real_gt)
fake_preds = netD(fake_coarse)


#### Update Generator #####
real_preds = netD(real_gt)
fake_preds = netD(fake_coarse)

改成:

#### Update Discriminator ###
real_preds = netD(real_gt.detach())
fake_preds = netD(fake_coarse.detach())


#### Update Generator #####
real_preds = netD(real_gt)
fake_preds = netD(fake_coarse)

查找方法不易,感谢支持

你可能感兴趣的:(深度学习成长之路,pytorch,python,深度学习)