CTC可以生成一个损失函数,用于在序列数据上进行监督式学习,不需要对齐输入数据及标签,经常连接在一个RNN网络的末端,训练端到端的语音或文本识别系统。CTC论文
本文主要是讲解用wrap_ctc来实现pytorch版本的CRNN的环境配置过程,用其来进行OCR端到端文本识别。(注:wrap_ctc是百度开源的一个模块,需要自己编译使用。在pytorch 1.0中,自带了CTC loss,用pytorch 1.0的话可以不用编译这个wrap_ctc)
CTC网络的输入是一个样本(图像)经过网络(一般是CNN+RNN)计算后生成的特征向量(特征序列),这部分可参考CRNN论文。
特征序列里各个向量是按序排布的,是从图像样本上从左到右的一个个小的区间映射过来的,可以设置区间的大小(宽度),宽度越小,获得的特征序列里的特征向量个数越多,极端情况下,可以设置区间宽度为1,这样就会生成width(图像宽度)个特征向量(作为后续RNN的输入)。
将CNN产生的一系列(假设为width个)特征序列作为后续RNN(在CRNN中用的是Bi-LSTM)的输入,可以得到一个width维的概率矩阵,这个概率矩阵就可以作为CTC的输入,用来计算CTC loss。
CTC网络的计算是为了得到特征序列最可能对应的标签对象,对语音识别是一段话,对文本识别是一段文字。
计算特征序列里每个特征向量(共N个)分别对应的n个可能结果的概率。如果当前的特征向量的预测结果不在样本标签列表里,就置预测结果为blank空格或下划线。计算结果从一个N维的特征序列,得到一个N×n的概率矩阵(就是上面所说的)。
计算上述预测的N×n的概率矩阵的所有可能结果的概率,中间涉及到去除重复字母和blank的操作。N个n维的特征向量(即N×n的概率矩阵)对应的所有可能的结果有 N n N^{n} Nn个,涉及到组合学,计算所有可能概率的成本会很高,但是CTC运用了动态规划(前后向算法,这部分推荐看一下HMM)以大幅降低计算的复杂性。
对识别过程,取出最大概率对应的结果作为识别结果输出;
对训练过程,取最大概率对应的结果跟真实标签之间的差异(计算编辑距离等方法),作为训练Loss,反向传输给前端网络。
CTC计算过程示意图:
默认你会了。。。
首先我用的是环境是python 3.6,pytorch 1.0.1,起码这个版本是没有问题的。
参考crnn-train-pytorch,这个就是官方的教程,这部分github上的教程没有问题,如果出现错误,可以去Issues中看一下有没有相同错误,我当时安装的时候也是出现了错误,在issues里面找到了答案,出问题一定要去看一下。下面我贴几个我编译时出现问题,然后找到的答案,希望能帮到大家。
每个人遇到的问题可能不一样,上面仅供参考,相信出问题issues里面大概率会有解决办法的。
注:上面第三个链接是原版wrap_ctc里面的Issues,这里面也可能有解决问题的答案。
warp-CTC是百度开源的一个可以应用在CPU和GPU上高效并行的CTC代码库,对CTC算法进行了并行处理。
warp-CTC安装:
git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
cd ../pytorch_binding
python setup.py install
添加环境变量:
vim ~/.bashrc
export WARP_CTC_PATH=/home/xxx/warp-ctc/build
验证pytorch中warp-CTC是否可用GPU例子:
cd /home/xxx/warp-ctc/pytorch_binding/tests
python test_gpu.py
import torch
from torch.autograd import Variable
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = Variable(torch.IntTensor([1, 2]))
label_sizes = Variable(torch.IntTensor([2]))
probs_sizes = Variable(torch.IntTensor([2]))
probs = Variable(probs, requires_grad=True) # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
print('PyTorch bindings for Warp-ctc')
到此整个CRNN的运行环境基本已经配置完毕(还需要安装一个lmdb,pip install lmdb即可),在编译wrap_ctc的时候可能会出一些莫名奇妙的Bug,这部分还是推荐先去Issues里面找答案。像这么成熟的模块,Issues里面基本已经包含所有常见错误了,其他有问题欢迎大家多交流,共同进步~