Anconda安装的pytorch依赖的cuda版本和系统cuda版本不一致问题

背景

  1. 使用Anaconda配置源码环境
  2. 源码需要使用python setup.py来编译依赖cuda的torch拓展模块,如 nms,ROIPool,ROIAlign等等
  3. 系统的CUDA和Conda装的cudatoolkit版本不同

问题

符合上述背景条件或者类似条件,会导致一些奇怪的错误,例如:

ImportError: ***/ATSS/atss_core/_C.cpython-36m-x86_64-linux-gnu.so: undefined symbol: __cudaRegisterFatBinaryEnd

原因分析

按照源码安装教程,conda安装pytorch时cudatoolkit版本为9.0,而编译源码时,有一个依赖CUDA的拓展模块atss_core._C,查看编译过程信息发现,使用的是系统默认CUDA,而非想象中的cudatoolkit包,这就导致拓展模块依赖的CUDA版本和Pytorch依赖的cuda版本不一,进而引发一些奇怪错误。

总结

在conda的虚拟环境中,编译依赖cuda的torch拓展模块时,并未优先使用cudatoolkit包,转而使用了系统CUDA包,如果两个CUDA包版本不一,可能会引发一些意想不到的问题。

推荐解决方案

一般出现这种问题时,源码作者可能指定了cudatoolkit,则不好去改动其版本。
推荐在系统上再安装一个对应CUDA版本,且不覆盖原有(系统默认)版本,然后使用CUDA_HOME环境变量来指出对应CUDA版本的安装目录,如 export CUDA_HOME=/usr/local/cuda-9.0,这样可以避免更改系统或用户文件。
通过这种方式,只要安装时不覆盖系统默认版本,不会影响其它用户。如果你仅在terminal中临时export,还能避免对其他依赖cuda程序的影响,很方便。

深入分析

按理说既然是在同一个虚拟环境下,编译涉及CUDA的拓展模块,应该优先搜索该虚拟环境下conda安装的cudatoolkit包,就能避免该问题。
进一步分析源码发现,cuda根目录由torch.utils.cpp_extension给出,而torch.utils.cpp_extension中的_find_cuda_home函数,只要没有设置CUDA_HOMECUDA_PATH环境变量,就会返回默认系统路径/usr/local/cuda

你可能感兴趣的:(PyTorch)