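PyTorch (before 1.0) has no built-in CTCLoss, so a third-party library, warp-ctc, has to be installed.
The warp-ctc source is at: https://github.com/SeanNaren/Warp-ctc
Build it following the steps on the project page:
git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
Then install the PyTorch binding (from the warp-ctc root directory):
cd pytorch_binding
python setup.py install
This fails with the following error: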
root@localhost:/home/ocrtrain/train/ocr/warp-ctc/pytorch_binding# python setup.py install
running install
running bdist_egg
running egg_info
creating warpctc_pytorch.egg-info
writing warpctc_pytorch.egg-info/PKG-INFO
writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt
writing top-level names to warpctc_pytorch.egg-info/top_level.txt
writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'
reading manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'
writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib.linux-x86_64-3.6
creating build/lib.linux-x86_64-3.6/warpctc_pytorch
copying warpctc_pytorch/__init__.py -> build/lib.linux-x86_64-3.6/warpctc_pytorch
running build_ext
building 'warpctc_pytorch._warp_ctc' extension
creating build/temp.linux-x86_64-3.6
creating build/temp.linux-x86_64-3.6/src
x86_64-linux-gnu-gcc -pthread -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/michael/ocrtrain/train/ocr/warp-ctc/include -I/usr/local/lib/python3.6/dist-packages/torch/include -I/usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.6/dist-packages/torch/include/TH -I/usr/local/lib/python3.6/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.6m -c src/binding.cpp -o build/temp.linux-x86_64-3.6/src/binding.o -std=c++11 -fPIC -DWARPCTC_ENABLE_GPU -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_warp_ctc -D_GLIBCXX_USE_CXX11_ABI=0
src/binding.cpp:10:11: fatal error: ATen/cuda/CUDAGuard.h: No such file or directory
#include "ATen/cuda/CUDAGuard.h"
^~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1
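Before editing any source, it can help to confirm whether the header the compiler complains about exists in the installed PyTorch at all. The sketch below is only an illustration: `find_header` and `torch_include_dir` are hypothetical helper names, and it assumes the headers live in an `include` directory inside the torch package, as the `-I.../torch/include` flag in the log above suggests.

```python
import importlib.util
import os


def find_header(include_dirs, relative_path):
    """Return the first existing path of relative_path under include_dirs, else None."""
    for directory in include_dirs:
        candidate = os.path.join(directory, relative_path)
        if os.path.isfile(candidate):
            return candidate
    return None


def torch_include_dir():
    """Locate <torch package>/include without importing torch (None if torch is absent)."""
    spec = importlib.util.find_spec("torch")
    if spec is None or not spec.submodule_search_locations:
        return None
    return os.path.join(list(spec.submodule_search_locations)[0], "include")


include_dir = torch_include_dir()
if include_dir is None:
    print("torch is not installed in this environment")
else:
    found = find_header([include_dir], "ATen/cuda/CUDAGuard.h")
    print("CUDAGuard.h:", found or "not found under " + include_dir)
```

If the header is reported as missing, the problem lies with the installed PyTorch package rather than with the CUDA toolkit itself.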
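I am on CUDA 10, and my guess is that this binding only supports up to CUDA 9, which is why the header cannot be found. (ATen/cuda/CUDAGuard.h is shipped with PyTorch itself rather than with the CUDA toolkit, so a mismatch between the binding and the installed PyTorch version is another possible cause.) I worked around it by building the CPU-only version.
Go to warp-ctc/pytorch_binding/src/ and open binding.cpp.
Edit it to remove the GPU parts, leaving: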
#include <iostream>
#include <vector>
#include <numeric>
#include <torch/extension.h>
#include "ctc.h"
int cpu_ctc(torch::Tensor probs,
            torch::Tensor grads,
            torch::Tensor labels,
            torch::Tensor label_sizes,
            torch::Tensor sizes,
            int minibatch_size,
            torch::Tensor costs,
            int blank_label)
{
    // Raw pointers into the tensors' memory
    float* probs_ptr = (float*)probs.data_ptr();
    float* grads_ptr = grads.storage() ? (float*)grads.data_ptr() : NULL;
    int* sizes_ptr = (int*)sizes.data_ptr();
    int* labels_ptr = (int*)labels.data_ptr();
    int* label_sizes_ptr = (int*)label_sizes.data_ptr();
    float* costs_ptr = (float*)costs.data_ptr();

    // Alphabet size is the last dimension of probs
    const int probs_size = probs.size(2);

    ctcOptions options;
    memset(&options, 0, sizeof(options));
    options.loc = CTC_CPU;
    options.num_threads = 0; // will use default number of threads
    options.blank_label = blank_label;

#if defined(CTC_DISABLE_OMP) || defined(__APPLE__)
    // have to use at least one
    options.num_threads = std::max(options.num_threads, (unsigned int) 1);
#endif

    // Ask warp-ctc how much scratch memory it needs, then allocate it
    size_t cpu_size_bytes;
    get_workspace_size(label_sizes_ptr, sizes_ptr,
                       probs_size, minibatch_size,
                       options, &cpu_size_bytes);

    float* cpu_workspace = new float[cpu_size_bytes / sizeof(float)];

    compute_ctc_loss(probs_ptr, grads_ptr,
                     labels_ptr, label_sizes_ptr,
                     sizes_ptr, probs_size,
                     minibatch_size, costs_ptr,
                     cpu_workspace, options);

    delete[] cpu_workspace;
    return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("cpu_ctc", &cpu_ctc, "CTC Loss function with cpu");
}
Run python setup.py install again; it now compiles and installs.
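A quick way to confirm that the install step registered the package, without importing the compiled extension yet, is to ask Python whether it can locate the module. This is a minimal sketch; `is_installed` is a hypothetical helper name, not part of warp-ctc.

```python
import importlib.util


def is_installed(module_name):
    """True if Python can locate a top-level module with this name."""
    return importlib.util.find_spec(module_name) is not None


print("warpctc_pytorch installed:", is_installed("warpctc_pytorch"))
```

If this prints False, the copy-the-directory approach described in the test step further down is the fallback.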
Second method. References:
https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.zh_cn.md
https://blog.csdn.net/AMDS123/article/details/73433926
Note that these instructions are for the Lua Torch binding (torch_binding), not the PyTorch one. In the warp-ctc root directory, run:
luarocks install http://raw.githubusercontent.com/baidu-research/warp-ctc/master/torch_binding/rocks/warp-ctc-scm-1.rockspec
To test the PyTorch binding, copy the warp-ctc/pytorch_binding/build/warpctc_pytorch directory into the same directory as the test script. Create the script with vi test.py, then run python test.py:
import torch
from torch.autograd import Variable
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = Variable(torch.IntTensor([1, 2]))
label_sizes = Variable(torch.IntTensor([2]))
probs_sizes = Variable(torch.IntTensor([2]))
probs = Variable(probs, requires_grad=True) # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
print('PyTorch bindings for Warp-ctc')
Alternatively, run the bundled test: cd warp-ctc/pytorch_binding/tests && python test_cpu.py
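The test script above does not print the loss, so it is useful to have a number to compare `cost` against. The following is a minimal pure-Python sketch of the standard CTC forward recursion, not warp-ctc's own code. It assumes, as warp-ctc does, that the inputs are unnormalised activations that get a softmax applied first, and that label 0 is the blank. `softmax` and `ctc_nll` are names made up for this illustration.

```python
import math


def softmax(row):
    """Numerically stable softmax over one time step."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def ctc_nll(activations, labels, blank=0):
    """Negative log-likelihood of labels given T rows of unnormalised scores."""
    probs = [softmax(row) for row in activations]
    # Extended sequence: a blank before, between and after the labels
    ext = [blank]
    for label in labels:
        ext += [label, blank]
    size = len(ext)

    # alpha[s] = total probability of all prefixes ending in ext[s]
    alpha = [0.0] * size
    alpha[0] = probs[0][ext[0]]
    if size > 1:
        alpha[1] = probs[0][ext[1]]

    for t in range(1, len(probs)):
        new_alpha = [0.0] * size
        for s in range(size):
            total = alpha[s]
            if s >= 1:
                total += alpha[s - 1]
            # Skipping a blank is allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += alpha[s - 2]
            new_alpha[s] = total * probs[t][ext[s]]
        alpha = new_alpha

    # Valid paths end in the last label or the trailing blank
    likelihood = alpha[-1] + (alpha[-2] if size > 1 else 0.0)
    return -math.log(likelihood)


acts = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]
print(round(ctc_nll(acts, [1, 2]), 4))  # 2.4629
```

With two time steps and two labels there is only one valid alignment, so the loss is simply the negative log of the product of the two softmax probabilities. The sketch works in probability space and will underflow on long sequences; real implementations use log space.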
For reference, the rockspec that the luarocks command installs from:
package = "warp-ctc"
version = "scm-1"

source = {
   url = "git://github.com/baidu-research/warp-ctc.git",
}

description = {
   summary = "Baidu CTC Implementation",
   detailed = [[
   ]],
   homepage = "https://github.com/baidu-research/warp-ctc",
   license = "Apache"
}

dependencies = {
   "torch >= 7.0",
}

build = {
   type = "command",
   build_command = [[
cmake -E make_directory build && cd build && cmake .. -DLUALIB=$(LUALIB) -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE) -j$(getconf _NPROCESSORS_ONLN) && make install
]],
   platforms = {},
   install_command = "cd build"
}
PyTorch 1.0 and later already include a CTC loss:
import torch
loss = torch.nn.CTCLoss()
Some people report that the built-in CTCLoss has bugs, but I have not run into any so far.
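For comparison, here is a sketch of the same toy example run through the built-in loss. One difference to keep in mind: torch.nn.CTCLoss expects log-probabilities, so the log_softmax is applied explicitly here, whereas warp-ctc takes raw activations. The choice of reduction='sum' is an assumption made for this example so that the value is the raw summed loss.

```python
import torch
import torch.nn.functional as F

# Shape (T=2, N=1, C=5): two time steps, batch of one, five classes (0 is blank)
activations = torch.tensor(
    [[[0.1, 0.6, 0.1, 0.1, 0.1]],
     [[0.1, 0.1, 0.6, 0.1, 0.1]]],
    requires_grad=True,
)
log_probs = F.log_softmax(activations, dim=2)

targets = torch.tensor([[1, 2]], dtype=torch.long)      # (N, S)
input_lengths = torch.tensor([2], dtype=torch.long)     # time steps per sample
target_lengths = torch.tensor([2], dtype=torch.long)    # labels per sample

ctc_loss = torch.nn.CTCLoss(blank=0, reduction='sum')
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
loss.backward()  # gradients end up in activations.grad

print(loss.item())
```

If both libraries are installed, comparing this value with the `cost` from the warp-ctc test script is a cheap way to check that they agree on a given input.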