莫凡pytorch GAN 学习 bug记录1

bug1:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 15]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

跟着莫凡老师学习GAN网络,写好代码却一直在报错,同时我也使用了retain_graph=True, 按理来说不应该报错了,看了很多解释说是版本问题,好像是inplace 设置问题,但是我不太懂。

求pytorch大神解答,问题出在哪里 - 虎扑社区

找到莫凡老师在github的源代码,发现可以运行,然后对比了一下,找到了一些原因,我修改了一下我的代码

图1, 初始代码

莫凡pytorch GAN 学习 bug记录1_第1张图片

 图2,修改之后的代码

莫凡pytorch GAN 学习 bug记录1_第2张图片

解决了这个问题,原因就是

G的输出作为 D的输入时,在反向计算D网络梯度时也会计算G网络的梯度

为了使两个计算图的梯度传递断开, 我们需要使用 D_input=G_output.detach()

 在计算 G的 反向传播的梯度时,就需要再一次计算G的输出和loss

你可能感兴趣的:(pytorch,bug,pytorch)