【已解决】运行GAN时Torch报错

目录

报错内容:

报错原因:

解决办法:

注意事项:


报错内容:

CPU下运行GAN代码报错

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

报错原因:

torch高于1.4.0以后,梯度更新变了。算G的loss的计算图是包含D的,但是在G的backward之前更新了D的值,这时候再去计算就不是和forward时候对应的梯度了。

解决办法:

修改torch和torchvision版本

# CUDA 10.0
conda install pytorch==1.0.0 torchvision==0.2.1 cuda100 -c pytorch

# CUDA 9.0
conda install pytorch==1.0.0 torchvision==0.2.1 cuda90 -c pytorch

# CUDA 8.0
conda install pytorch==1.0.0 torchvision==0.2.1 cuda80 -c pytorch

# CPU Only
conda install pytorch-cpu==1.0.0 torchvision-cpu==0.2.1 cpuonly -c pytorch

注意事项:

如果只改torch或者只改torchvision,结果是不对的,必须一起改。

你可能感兴趣的:(bug,Deep,learning,python,深度学习,人工智能,python,bug)