当前场景文字识别较为主流的方法分为 attention mechanism(代表:Show, Attend and Read; Transformer-based attention; 各种各样的 2D Attention) 和 STN + CRNN / DenseNet + CTC,二者的主要区别是一个是在解码前给文字区域赋予较高的权重,聚焦于文本信息,弱化无关背景信息,另一个是在一开始修正曲形的文字得到水平规整的文字进行识别;
基于注意力机制的几篇前面博文都有涉及,感兴趣可以参考之前的文章:
由于最近组里相关项目涉及到基于 STN 的文字识别算法框架,所以大概总结一下相关模型结构。
基本框架:
intputs --> STN --> Feture Extraction (CNN) --> Sequence Modeling(optional) --> Prediction(Decoder) --> outputs
1. 总体概览
以较为经典的 TPS-STN 为例,定位网(localization network)络定位出一系列基准点(fiducial points),格点生成器(grid generator)依据这些基准点生成对应的一系列采样点(sampling grid),采样器通过匹配两组基准点和变换矩阵 生成基于原图文字区域的格点网络,最后通过双线性差值得到修正的图像,一般和原图尺寸相同;
2. 定位网络 localization network
通过CNN回归一系列坐标点(x-y),通过激活函数 tanh 归一化为[-1, 1],共 2K 个值, 确定 K 个基准点, 记为;
注意 K 为常亮;
3. 格点生成器 grid generator
首先定义另一组 K 个基准点(base fiducial points),
均匀分布在修正的照片的上下两侧,下图中左边绿色的点是
,右边的蓝色的点就是
;通过变化矩阵T完成从左图到有图的变化;
,
表示输出图像的每一个像素点的坐标(x, y),
为输出(输入)图像像素个数;
,
表示格点生成器对应原图中生成的格点坐标(x, y),
为输出(输入)图像像素个数;
是通过格点生成器生成的 上的像素值,
4. 采样器 sampling grid
,
表示双线性差值,得到从输入图像到输出图像的变换;
5. 总结
整个 TPS-STN 结构是可微的,所以可以通过反向传播反传梯度自动学习,是一个无监督的过程,可以应用在倾斜或小曲率的曲形文字修正上。
Encoder: CNN+Sequence Modeling, CNN 的选择有很多,目前效果比较好的就是 ResNet 和 DenseNet;Sequence Modeling 当前比较主流的做法就是使用 BiLSTM 进行序列建模,关联特征向量的上下文信息;这部分没太多好说的,要注意的是 DenseNet 的中间层特征图是跨层连接,所以具有全局感受野,此外其网络规模相对同尺度 ResNet 更小,但训练更占显存;
选用 CTC 进行预测输出的话,其输入必为列方向上一维的特征向量,引入blank空白字符,对于每一列向量预测一个字符,最后通过 beta-decode 删除空白字符和重复字符;
上图 RNN 每一步的输出其实都是一组概率分布,, 对于第一个矩形框,输出可能是
;
定义:
捕捉decoder输入序列的信息,给其赋予不同权重,最后再输出最终的字符,一般是每次输出的预测字符的一部分输入来自于上一个字符的输出,第一个输出字符固定为 BOS_TOKEN,输出到 EOS_TOKEN 为止;
其他:
字符识别的beam search只在inference时(train的时候label已知),且只用于输出时存在前后依赖关系的网络结构中使用;
这种情况下基于贪心的策略不是很合理,因此beam search的做法是选定一个K值(K > 1),每次取前K个概率最大的字符和前面保留的前K个概率最大的字符串(长度为n)排列组合,再得到概率最大的长度为n+1的字符串,依次向后预测和生成。