pytorch0.4安装ctc_loss

文章目录

      • 安装流程
      • 注意:
      • 常见问题

前言: pytorch0.4.1的安装可以参考我的另外一篇博客pytorch0.4.1安装CTC loss
pytorch1.0后框架自带有ctc损失函数

安装流程

  1. 克隆项目,在根目录下新建build文件夹
git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
  1. 对原文件:pytorch_binding/src/binding.cpp进行2处修改:
1 at 92 lines
int probs_size = THCudaTensor_size(state, probs, 2);
2 at l05 lines
void* gpu_workspace;
THCudaMalloc(state, &gpu_workspace, gpu_size_bytes);
  1. 编译安装
cd pytorch_binding
python setup.py install
  1. 测试
cd pytorch_binding/tests/
python test_cpu.py
python test_gpu.py

>>>
=========================================================================================== test session starts ============================================================================================
platform linux -- Python 3.6.7, pytest-4.0.0, py-1.8.0, pluggy-0.9.0
rootdir: /home/chenjun/warp-ctc/pytorch_binding, inifile: setup.cfg
collected 4 items                                                                                                                                                                                          

test_cpu.py ....                                                                                                                                                                                     [100%]

========================================================================================= 4 passed in 0.10 seconds =========================================================================================
(torch04) chenjun@chenjun-MS-7A71:~/warp-ctc/pytorch_binding/tests$ python test_gpu.py 
=========================================================================================== test session starts ============================================================================================
platform linux -- Python 3.6.7, pytest-4.0.0, py-1.8.0, pluggy-0.9.0
rootdir: /home/chenjun/warp-ctc/pytorch_binding, inifile: setup.cfg
collected 4 items                                                                                                                                                                                          

test_gpu.py ....                                                                                                                                                                                     [100%]

========================================================================================= 4 passed in 2.02 seconds =========================================================================================
(torch04) chenjun@chenjun-MS-7A71:~/warp-ctc/pytorch_binding/tests$ git commit

注意:

  1. pytorch0.4.1 不需要做以上修改, 0.4.0才需要做以上修改
  2. pytorch0.4.0 修改之后,在测试gpu的时候,会一直卡着不动。

常见问题

  1. gcc编译的时候出现src/binding.cpp:6:29: fatal error: torch/extension.h: 没有那个文件或目录,详细错误见下图,这应该是项目不同版本引起的问题。
gcc -pthread -B /home/yangna/yangna/tool/anaconda2/envs/torch40/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/yangna/yangna/tool/warp-ctc/include -I/home/yangna/yangna/tool/anaconda2/envs/torch40/lib/python3.6/site-packages/torch/lib/include -I/home/yangna/yangna/tool/anaconda2/envs/torch40/lib/python3.6/site-packages/torch/lib/include/TH -I/home/yangna/yangna/tool/anaconda2/envs/torch40/lib/python3.6/site-packages/torch/lib/include/THC -I/usr/local/cuda/include -I/home/yangna/yangna/tool/anaconda2/envs/torch40/include/python3.6m -c src/binding.cpp -o build/temp.linux-x86_64-3.6/src/binding.o -std=c++11 -fPIC -DWARPCTC_ENABLE_GPU -DTORCH_EXTENSION_NAME=warpctc_pytorch._warp_ctc
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
src/binding.cpp:6:29: fatal error: torch/extension.h: 没有那个文件或目录
compilation terminated.
error: command 'gcc' failed with exit status 1

解决: git checkout ac045b6, 切换分支

你可能感兴趣的:(pytorch)