图像文字识别初探(二)-FAN(Focusing Attention Network)

图像文字识别初探(一)-CRNN(Convolution Recurrent Neural Network)和DTRN(Deep-text Recurrent Network)

图像文字识别初探(二)-FAN(Focusing Attention Network)

图像文字识别初探(三)-Mask TextSpotterr

图像文字识别初探(四)-- single-shot text detector

FAN(Focusing Attention Network)

论文:Focusing Attention: Towards Accurate Text Recognition in Neural Images

此篇论文是海康威视联合复旦和上交发表的,论文中提出FAN(Focusing Attention Network)算法,FAN包括两个主要的组件,AN(attention network)用来识别目标字符和FN(focusing network)通过检查AN的注意力是否在图片中的目标字符区域,然后调整注意力。整个FAN网络架构包括一个基于ResNet网络的特征提取,一个AN网络和一个FN网络,如下图所示,CNN-BLSTM从输入图像I中提取特征向量序列,AN使用提取的特征输出alignment factor和glimpse vector,FN聚焦AN中的注意力在输入图像中合适的目标字符区域。

图像文字识别初探(二)-FAN(Focusing Attention Network)_第1张图片

Attention Network

一般来说,基于注意力机制的文本识别通常是encoder-decoder架构,在encoder阶段,一张输入图像经过CNN/LSTM转换成特征向量序列,序列中每一个特征向量对应输入图像中的一个区域,在此论文中,此区域为attention regions; 在decoder阶段,AN网络首先计算alignment factor,factor的计算牵涉到the history of target characters和the encoded feature vectors这两个变量,于是就可以使得attention region和相对应的ground-truth-labels对齐,接下来,基于glimpse vectors和 the history of target characters,一个RNN网络就可以用来生成目标字符了。

对于一个输入图像I,图像首先被一个CNN/LSTM编码成特征向量序列,Encoder\left ( I \right ) = \left ( h_{1}, \cdots, h_{T} \right ),这时Encoder阶段; 在第t步,decoder输出y_{t}=Generate\left ( s_{t}, g_{t} \right ),其中s_{t}是RNN第t个时刻的隐含状态s_{t}=RNN\left ( y_{t-1}, g_{t}, s_{t-1} \right )g_{t}是编码器输出的特征向量序列\left ( h_{1}, \cdots, h_{T} \right )的加权和g_{t}=\sum_{j=1}^{T}\alpha _{ij}h_{j},其中\alpha _{ij}为algiment factor。对\alpha _{ij}的计算是通过对\left ( h_{1}, \cdots, h_{T} \right )中的每一个元素的得分来评估,并对得分进行归一化:

                                                                      \\ e_{t,j}=v^{T}tanh\left ( Ws_{t-1}+Vh_{j}+b \right ) \\ \alpha_{t,j}=\frac{exp\left ( e_{t,j} \right )}{\sum_{j=1}^{T}exp\left ( e_{t,j} \right )}

其中v, W,V和b全是可训练的参数。

由于解码器需要输出一个变长的序列,在目标集合中,添加一个特殊的字符end-of-sentence(EOS),当出现EOS时,解码器完成字符的输出。整个attention model的损失函数为L_{Att}=-\sum_{t} \ln P\left (\hat{y}_{t} | I, \theta \right ),其中\hat{y}_{t}是第t个ground truth字符。

在场景文字识别中,AN模型主要有两个方面的缺点,其一是attention drift问题,其二是在巨大的场景文本数据下很难训练一个这样的模型。下面我们主要探讨一下attention drift问题。由于低质量(如模糊,污损和噪音等)图片和一些复杂图片(如扭曲或者重叠字符,不同字符,不用尺寸,不同颜色或者复杂的背景)的影响,模型在glimpse vector的整合上没有对齐约束,产生不正确的alignment factor,导致注意力区域和标签区域错误匹配,就是所谓的attention drift。如下图中的子图片(a)所示,为了解决这个attention drift问题,此篇论文中加入了FN网络。

图像文字识别初探(二)-FAN(Focusing Attention Network)_第2张图片

Focusing Network

FN网络的引入是为了解决attention drift问题,,如下图所示,它主要分为两大步:

  1. 计算每一个预测lable的中心注意点(attention center)
  2. 通过生成在注意力区域(attention region)的概率分布将注意力集中在目标区域(target region)

图像文字识别初探(二)-FAN(Focusing Attention Network)_第3张图片

计算attentiopn center

前提条件:在attention model中,每一个特征向量与原图中的一区域相对应。在卷积和池化操作中,定义输入特征图为N \times D_{i} \times H_{i} \times W_{i},输出特征图为N \times D_{o} \times H_{o} \times W_{o},其中N,D,W,H分别表示batch_size,通道数和特征图的宽和高。依据卷积里的计算,有

                                                                     \\ H_{o}=\frac{H_{i} + 2 * pad_{H} - kernel_{H}}{stride_{H}} + 1 \\ W_{o}=\frac{W_{i} + 2 * pad_{W} - kernel_{W}}{stride_{W}} + 1

假设在L层的坐标点(x,y),我们计算它在L-1层的recetive field为bounding box坐标r = \left ( x_{min}, x_{max}, y_{min}, y_{max} \right ),计算公式为:

图像文字识别初探(二)-FAN(Focusing Attention Network)_第4张图片

在第t步,我们通过递归地计算上述公式得到h_{j}在输入图像中receptive field,并且选择此receptive field的中心作为attention centerc_{t,j}=location\left ( j \right ),其中j是h_{j}的下标,location是计算receptive field中心的函数,因此对目标字符y_{t}在输入图像中的attention center 为c_{t}=\sum_{j=1}^{T}\alpha_{t,j}c_{t,j}

将注意力集中在目标区域

计算目标字符y_{t}的attention center后,我们从输入图像或者一个卷积层输出裁剪一块P\left ( P_{H},P_{W} \right )大小的特征图,F_{t}=Crop\left ( F, c_{t}, P_{H},P_{W} \right ),其中F为图像或者卷积特征图,P是在输入图像中ground-truth区域的最大值。得到裁剪后的特征图后,计算每一个attention region的energy distribution,e_{t}^{\left ( i,j \right )}=tanh\left ( Rg_{t} + S F_{t}^{\left ( x,y \right )} +b\right ),其中R,S,b都是可训练的参数,(i,j)则表示第i \times P_{W} + j特征向量,那么被选中的区域的概率分布为:

                                                                      P_{t}^{\left ( i,j,k \right )}=\frac{exp\left ( e_{t}^{\left ( i,j,k \right )} \right )}{\sum_{{k}'}^{K} exp \left ( e_{t}^{\left ( i,j,{k }' \right )} \right )}

其中K是类别标签的数目

FN的focusing 损失函数为L_{focus} = - \sum_{t}^{M} \sum_{i}^{P_{W}} \sum_{j}^{P_{H}} log P\left ( \hat{y}_{t}^{\left ( i,j \right )}|I,\omega \right ),其中\hat{y}_{t}^{\left ( x,y \right )}是ground-truth pixel label,\omega是FN的全部参数

所以整个FAN的损失函数为:

                                                             L = \left ( 1 - \lambda \right )L_{Att} + \lambda L_{focus}

基于attention的加码器利用学习到的字符概率统计输出字符序列。对于无约束的文本识别(无字典),直接选择最大概率的字符;对于有约束的文本识别,根据不同尺寸的字典,计算字典中所有字符的条件概率,选择最高概率的作为输出

总结

  • 有效地对齐注意力区域和图像中的目标区域,成功解决了attention drift问题
  • 能准确识别复杂或者低质量图像中的文字
  • 带字典和不带样本的样本都可以

 

你可能感兴趣的:(图像文字识别,机器学习,深度学习,机器学习和深度学习之旅)