【flash attention安装】成功解决flash attention安装: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6Tensor

【大模型-flash attention安装】成功解决flash attention安装site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
本次修炼方法请往下查看
在这里插入图片描述

欢迎莅临我的个人主页 这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地!
相关内容文档获取 微信公众号
相关内容视频讲解 B站

博主简介:AI算法驯化师,混迹多个大厂搜索、推荐、广告、数据分析、数据挖掘岗位 个人申请专利40+,熟练掌握机器、深度学习等各类应用算法原理和项目实战经验

技术专长: 在机器学习、搜索、广告、推荐、CV、NLP、多模态、数据分析等算法相关领域有丰富的项目实战经验。已累计为求职、科研、学习等需求提供近千次有偿|无偿定制化服务,助力多位小伙伴在学习、求职、工作上少走弯路、提高效率,近一年好评率100%

博客风采: 积极分享关于机器学习、深度学习、数据分析、NLP、PyTorch、Python、Linux、工作、项目总结相关的实用内容。

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

文章目录

  • 1. 问题介绍
  • 2. flash-attn解决问题
    • 2.1 直接pip安装flash-attn
    • 2.2 通过源码安装
    • 2.3 最终解决方案

下滑查看解决方法

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

  

1. 问题介绍

  在进行大模型训练时,我们通过为了解决内存采用flash attention策略来优化模型的性能,具体flash attention策略的原理可以自行去看论文,在实际的安装过程中坑太多了,最为经典的坑就是安装成功但是报各种各样的问题,最为经典的模型为目前最为火的模型为intervl,其为了处理大量的图片会采用flash attention策略去优化内存,但是在搭配其环境时,里面有居多的坑,具体为如果安装好最新的pytorch就会报错,下面为最终的相关的配置,具体为:

torch                  2.1.0
torchaudio             2.1.0
torchvision            0.16.0
flash-attn             2.5.6
transformers           4.37.2

  在安装的过程中出现的问题,最开始pytorch最新版本为:


torch                  2.2.0

  看晚上很多的人说需要降低版本,因此,最后将版本降到2.1.0版本,至于高版本可不可以这个具体得看了,反正我调通了就没去试了。

2. flash-attn解决问题

2.1 直接pip安装flash-attn

  很多大模型的官网说的直接安装,具体的命令如下所示:

pip install flash-attn==2.3.6 --no-build-isolation

  上述的安装会成功,但是在导入的时候报如下的错误,具体为:

site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE

  根据很多的查找,感觉是这个版本太低了和pytorch版本不匹配的问题,因此,升级版本

2.2 通过源码安装

  从github上面 [https://github.com/Dao-AILab/flash-attention],安装上面的步骤进行安装,但是在安装的过程会遇到如下问题:
【flash attention安装】成功解决flash attention安装: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6Tensor_第1张图片
  晚上说的问题是说是安装的过程中网络的问题,就在崩溃的时候,将其进行卸载,然后安装某个特定的版本就行了

2.3 最终解决方案

  将pytroch安装到特定的版本2.1.0,然后直接如下的命令:

pip install flash_attn==2.5.6

  最终测试安装成功为,在conda环境中输入,结果为下面的就是安装成功了:

>>> from transformers import AutoModel, GenerationConfig, LlamaForCausalLM
>>> 

你可能感兴趣的:(AIGC,flash_attntion,大模型,intervl)