【大模型-flash attention安装】成功解决flash attention安装site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
本次修炼方法请往下查看
欢迎莅临我的个人主页 这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地!
相关内容文档获取 微信公众号
相关内容视频讲解 B站
博主简介:AI算法驯化师,混迹多个大厂搜索、推荐、广告、数据分析、数据挖掘岗位 个人申请专利40+,熟练掌握机器、深度学习等各类应用算法原理和项目实战经验。
技术专长: 在机器学习、搜索、广告、推荐、CV、NLP、多模态、数据分析等算法相关领域有丰富的项目实战经验。已累计为求职、科研、学习等需求提供近千次有偿|无偿定制化服务,助力多位小伙伴在学习、求职、工作上少走弯路、提高效率,近一年好评率100% 。
博客风采: 积极分享关于机器学习、深度学习、数据分析、NLP、PyTorch、Python、Linux、工作、项目总结相关的实用内容。
在进行大模型训练时,我们通过为了解决内存采用flash attention
策略来优化模型的性能,具体flash attention策略的原理可以自行去看论文,在实际的安装过程中坑太多了,最为经典的坑就是安装成功但是报各种各样的问题,最为经典的模型为目前最为火的模型为intervl
,其为了处理大量的图片会采用flash attention策略去优化内存,但是在搭配其环境时,里面有居多的坑,具体为如果安装好最新的pytorch就会报错,下面为最终的相关的配置,具体为:
torch 2.1.0
torchaudio 2.1.0
torchvision 0.16.0
flash-attn 2.5.6
transformers 4.37.2
在安装的过程中出现的问题,最开始pytorch最新版本为:
torch 2.2.0
看晚上很多的人说需要降低版本,因此,最后将版本降到2.1.0版本,至于高版本可不可以这个具体得看了,反正我调通了就没去试了。
很多大模型的官网说的直接安装,具体的命令如下所示:
pip install flash-attn==2.3.6 --no-build-isolation
上述的安装会成功,但是在导入的时候报如下的错误,具体为:
site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
根据很多的查找,感觉是这个版本太低了和pytorch版本不匹配的问题,因此,升级版本
从github上面 [https://github.com/Dao-AILab/flash-attention],安装上面的步骤进行安装,但是在安装的过程会遇到如下问题:
晚上说的问题是说是安装的过程中网络的问题,就在崩溃的时候,将其进行卸载,然后安装某个特定的版本就行了
将pytroch安装到特定的版本2.1.0,然后直接如下的命令:
pip install flash_attn==2.5.6
最终测试安装成功为,在conda环境中输入,结果为下面的就是安装成功了:
>>> from transformers import AutoModel, GenerationConfig, LlamaForCausalLM
>>>