端到端的OCR:LSTM+CTC的实现

前面提到了用CNN来做OCR。这篇文章介绍另一种做OCR的方法,就是通过LSTM+CTC。这种方法的好处是他可以事先不用知道一共有几个字符需要识别。之前我试过不用CTC,只用LSTM,效果一直不行,后来下决心加上CTC,效果一下就上去了。

CTC是序列标志的一个重要算法,它主要解决了label对齐的问题。有很多实现。百度IDL在16年初公开了一个GPU的实现,号称速度比之前的theano-ctc, stanford-ctc都要快。Mxnet目前还没有ctc的实现,因此决定吧warpctc集成进mxnet。

根据issue里作者们的建议,决定和集成torch一样,写一个plugin,因此C++代码放在plugin/warpctc目录中。整个集成任务其实就是写一个wrapctc的op。代码在 plugin/warpctc/warpctc-inl.h.

CTC这一层其实和SoftmaxOutput很像。其实他们的forward的实现就是一模一样的。唯一的差别就是backward中grad的实现,在这里需要调用warpctc的compute_ctc_loss函数来计算梯度。实际上warpctc的主要接口也就是这个函数。

下面说说具体怎么用lstm+ctc来做ocr的任务。详细的代码在 examples/warpctc/lstm_ocr.py。这里只说说大体思路。

假设我们要解决的是4位数字的识别,图片是80*30的图片。那么我们就将每张图片按列切分成80个30维的向量。然后作为一个lstm的80个输入。一个lstm的输出和输入数目应该是相同的。而我们的预测目标却只有4个数字。而不是80个数字。在没有用ctc时我想了两个解决方案。第一个是用encode-decode模式。也就是80个输入做encode,然后decode成4个输出。实测效果很挫。第二个是把4个label每个copy20遍,从而变成80个label。实测也很挫。没办法,最后只能用ctc loss了。

用ctc loss的体会就是,如果input的长度远远大于label的长度,比如我这里是80和4的关系。那么一开始的收敛会比较慢。在其中有一段时间cost几乎不变。此刻一定要有耐心,最终一定会收敛的。在ocr识别的这个例子上最终可以收敛到95%的精度。

目前代码还在等待merge。pull request。

---------------

欢迎关注 微信公众号【ResysChina】

你可能感兴趣的:(端到端的OCR:LSTM+CTC的实现)