运行 pytorch gpu 时,报这个错
网上有很多人也遇到这个问题,有人说是 CUDA 和 cudnn 的版本匹配问题,有人说需要重装 Pytorch,CUDA,cudnn。我看了官网,版本是匹配的,试着重装了也不管用,而且我按照另一个系统的版本装也不行。
可以看到每次报错都在 conv.py 这个文件,就是在做 CNN 运算时出的错。
解决方法是引入如下语句
import torch
torch.backends.cudnn.enabled = False
这句话的意思是不用 cudnn 加速了。
GPU,CUDA,cudnn 的关系是:
参考:GPU,CUDA,cuDNN的理解
cudnn 默认会使用,既然目前解决不了匹配问题,就先不用了。这样 gpu 照样能工作,但可能没有用上 cudnn 那么快。
如果有朋友知道该怎么解决可能的版本问题,欢迎交流~
附: