序列识别问题的聚合交叉熵损失函数(ACE损失函数)

文本识别算法介绍

文本识别问题是一个经典的序列预测问题,他输入一个有序列信息的三维图像,输出一个预测序列。常用的文本识别框架为CNN+BiLSTM+CTC,和CNN+BiLSTM+Attention。经过CNN+BiLSTM将三维图像提取特征,得到2维的特征序列(T*C),然后通过CTC或Attention将特征序列转化为预测结果。
给定一张来自于训练集Q的图像I,它的文本标签S,文本所包含的类别{1, 2, · · · , |C|},这张图像文本序列的长度L,通常的文本识别问题的损失函数为:

L ( w ) = − ∑ ( I , S ) ∈ Q l o g ( P ( S ∣ I ; w ) ) = − ∑ ( I , S ) ∈ Q ∑ l = 1 L l o g ( P ( S l ∣ l , I ; w ) ) L(w)=-\sum_{(I,S)\in Q}log(P(S|I;w))=-\sum_{(I,S)\in Q}\sum_{l=1}^{L}log(P(S_{l}|l,I;w)) L(w)=(I,S)Qlog(P(SI;w))=(I,S)Ql=1Llog(P(Sll,I;w))

其中 P ( S l ∣ l , I ; w ) P(S_{l}|l,I;w) P(Sll,I;w)表示在被预测序列的第l个字符预测结果为 S l S_{l} Sl的条件概率。
对于上述公式的计算非常困难,因为得到的文本特征与标签文本序列存在不对齐的问题,实际上不能直接使用第二个式子。
CTC和Attention分别从两个方面解决了这个问题。

  • CTC是将得到的特征序列每一个时刻都预测后直接去掉其中的空格,仅得到剩下的预测字符序列,然后将这个字符序列与标签序列计算交叉熵损失。但是他的关键在于如何反向传播,CTC通过隐马尔科夫模型中的前向后向算法从标签序列倒推回可能得到这个标签的预测序列,这些序列的概率权值是不同的,然后再进行后续的反向传播。CTC层本身是没有变量的,因此他训练的目的是使特征提取时能够学会文本序列的排布信息特征。它的损失计算是基于第一个公式 − ∑ ( I , S ) ∈ Q l o g ( P ( S ∣ I ; w ) ) -\sum_{(I,S)\in Q}log(P(S|I;w)) (I,S)Qlog(P(SI;w))
  • Attention是直接增加了一个可训练的解码层,将非常长的特征序列解码为长的不同的预测文本。这相当于是将CTC中的前向反向算法替换为一种参数可训练的注意力机制。通过训练能够得到要预测某个位置的文本字符,需要从哪些特征中得到。相当于Attention层本身学习到了一种文本的排布信息特征。它的损失计算是基于第二个公式 − ∑ ( I , S ) ∈ Q ∑ l = 1 L l o g ( P ( S l ∣ l , I ; w ) ) -\sum_{(I,S)\in Q}\sum_{l=1}^{L}log(P(S_{l}|l,I;w)) (I,S)Ql=1Llog(P(Sll,I;w))

ACE交叉熵损失

本文提出了一种新颖的损失函数,这个损失函数不考虑序列中字符间的顺序,仅仅考虑一个字符串中某个类别的字符出现的次数。
我们在没有Attention机制的网络中直接计算 − ∑ ( I , S ) ∈ Q ∑ l = 1 L l o g ( P ( S l ∣ l , I ; w ) ) -\sum_{(I,S)\in Q}\sum_{l=1}^{L}log(P(S_{l}|l,I;w)) (I,S)Ql=1Llog(P(Sll,I;w)) 是一种错误的做法,因为存在字符序列与特征序列的错位不对齐。基于此,文中提出了一种不需要考虑对齐的方案,即不考虑特征的顺序,仅仅计算各类别字符出现次数。

L ( w ) = − ∑ ( I , S ) ∈ Q ∑ l = 1 L l o g ( P ( S l ∣ l , I ; w ) ) ≈ − ∑ ( I , S ) ∈ Q ∑ k = 1 ∣ C ∣ l o g ( P ( N k ∣ k , I ; w ) ) L(w)=-\sum_{(I,S)\in Q}\sum_{l=1}^{L}log(P(S_{l}|l,I;w))\approx -\sum_{(I,S)\in Q}\sum_{k=1}^{|C|}log(P(N_{k}|k,I;w)) L(w)=(I,S)Ql=1Llog(P(Sll,I;w))(I,S)Qk=1Clog(P(Nkk,I;w))

其中|C|表示类别数, P ( N k ∣ k , I ; w ) P(N_{k}|k,I;w) P(Nkk,I;w) 表示在图像I的预测结果中,第k个类别的字符出现的次数等于标签中给定次数 N k N_{k} Nk 的条件概率。
例如标签字符串为students,则损失函数的目标是,使识别结果的s,t出现两次,其他类出现一次(包括空白类)。

基于回归的ACE损失函数

我们通过CNN+BiLSTM得到的特征序列维度为(T * K),其中T为序列长度,K为字符类别数,我们定义输出的特征序列张量为Y,第t个时刻的特征向量为 y t y^{t} yt,第t个时刻第k个类别的预测概率为 y k t y_{k}^{t} ykt。整个字符序列中所有位置第k个类别出现的总概率为 y k = ∑ t = 1 T y k t y_{k}=\sum_{t=1}^{T}y_{k}^{t} yk=t=1Tykt
我们定义 y k y_{k} yk N k N_{k} Nk 的平方损失(回归损失):

m a x ∑ k = 1 ∣ C ∣ l o g ( P ( N k ∣ k , I ; w ) ) ⇔ m i n ∑ k = 1 ∣ C ∣ ( N k − y k ) 2 max\sum_{k=1}^{|C|}log(P(N_{k}|k,I;w))\Leftrightarrow min\sum_{k=1}^{|C|}(N_{k}-y_{k})^2 maxk=1Clog(P(Nkk,I;w))mink=1C(Nkyk)2

数据集的损失函数表示为:

L ( w ) = 1 2 ∑ ( I , S ) ∈ Q ∑ k = 1 ∣ C ∣ ( N k − y k ) 2 L(w)=\frac{1}{2}\sum_{(I,S)\in Q}\sum_{k=1}^{|C|}(N_{k}-y_{k})^2 L(w)=21(I,S)Qk=1C(Nkyk)2

T表示预测文本长度,|S|表示标签文本长度,我们用(T-|S|)表示字符串中空白字符的个数 N ϵ = T − ∣ S ∣ N_{\epsilon }=T-|S| Nϵ=TS

ACE回归损失梯度

首先损失 L ( w ) L(w) L(w) 对输出 y k t y_{k}^{t} ykt 求导

∂ L ( w ) ∂ y k t = ∂ L ( w ) ∂ y k ∂ y k ∂ y k t = ( y k − N k ) = ( ∑ t = 1 T y k t − N k ) \frac{\partial L(w)}{\partial y_{k}^{t}}=\frac{\partial L(w)}{\partial y_{k}}\frac{\partial y_{k}}{\partial y_{k}^{t}}=(y_{k}-N_{k})=(\sum_{t=1}^{T}y_{k}^{t}-N_{k}) yktL(w)=ykL(w)yktyk=(ykNk)=(t=1TyktNk)

其中 y k t y_{k}^{t} ykt 由softmax层得到,

y k t = e a i ∑ j e a j y_{k}^{t}=\frac{e^{a_{i}}}{\sum_{j}e^{a_{j}}} ykt=jeajeai

y i y_{i} yi a i a_{i} ai 求导得到

∂ y k t ∂ a i = y i ( δ i j − y j ) \frac{\partial y_{k}^{t}}{\partial a_{i}}=y_{i}(\delta_{ij}-y_{j}) aiykt=yi(δijyj)

其中 当 i = j i=j i=j时, δ i j = 1 \delta_{ij}=1 δij=1,否则 δ i j = 0 \delta_{ij}=0 δij=0

最终ACE回归损失梯度表示为:

∂ L ( I , S ) ∂ a k t = ∑ k ′ = 1 ∣ C ∣ ∂ L ( I , S ) ∂ y k ′ t ∂ y k ′ t ∂ a k t = ∑ k ′ = 1 ∣ C ∣ ( y k ′ − N k ) ∗ y k ′ t ( δ k k ′ − y k t ) = ( y k − N k ) ∗ y k t ( 1 − y k t ) − ∑ k ′ = 1 , k ≠ k ′ ∣ C ∣ ( y k ′ − N k ) ∗ y k ′ t y k t \frac{\partial L(I,S)}{\partial a_{k}^{t}}=\sum_{k'=1}^{|C|}\frac{\partial L(I,S)}{\partial y_{k'}^{t}}\frac{\partial y_{k'}^{t}}{\partial a_{k}^{t}}=\sum_{k'=1}^{|C|}(y_{k'}-N_{k})*y_{k'}^{t}(\delta_{kk'}-y_{k}^{t})=(y_{k}-N_{k})*y_{k}^{t}(1-y_{k}^{t})-\sum_{k'=1,k\neq k'}^{|C|}(y_{k'}-N_{k})*y_{k'}^{t}y_{k}^{t} aktL(I,S)=k=1CyktL(I,S)aktykt=k=1C(ykNk)ykt(δkkykt)=(ykNk)ykt(1ykt)k=1,k̸=kC(ykNk)yktykt

回归损失的梯度消失

上面的回归损失函数存在着梯度消失问题,在训练开始的几个阶段,我们的输出对每个类别都有着平均的输出,即 y k ′ t = 1 / ∣ C ∣ {y_{k'}^{t}=1/|C|} ykt=1/C。当我们的类别数C比较大时,例如汉字识别,类别数高达数千,此时 y k ′ t y_{k'}^{t} ykt 的数量级是 1 0 − 3 10^-3 103,上面公式中,数量级大约是 y k ′ t 2 {y_{k'}^{t}}^{2} ykt2,即 1 0 − 6 10^-6 106,即相当小的梯度更新,完全无法训练。
即使我们的类别数并没有那么多,但是我们的梯度是 y k ′ t 2 {y_{k'}^{t}}^{2} ykt2 ,它再对前面的层求导,每一次都会乘 y k ′ t y_{k'}^{t} ykt,即梯度将会以指数级减小,梯度消失问题。

基于交叉熵的ACE损失函数

我们将网络预测的各类别字符数量当作一个概率分布, y k ‾ = y k / T \overline{y_{k}}=y_{k}/T yk=yk/T,将标签各类别字符数量当作另一个概率分布, N k ‾ = N k / T \overline{N_{k}}=N_{k}/T Nk=Nk/T
我们使用交叉熵函数表示预测结果分布和标签分布的相似程度:
L ( I , S ) = − ∑ k = 1 ∣ C ∣ N k ‾ ∗ l n y k ‾ L(I,S)=-\sum_{k=1}^{|C|}\overline{N_{k}}*ln\overline{y_{k}} L(I,S)=k=1CNklnyk

这个损失函数对softmax之前的logits a k t a_{k}^{t} akt 求梯度:

∂ L ( I , S ) ∂ a k t = ∑ k ′ = 1 ∣ C ∣ ∂ L ( I , S ) ∂ y k ‾ ∂ y k ‾ ∂ y k ′ t ∂ y k ′ t ∂ a k t = ∑ k ′ = 1 ∣ C ∣ − N k ‾ y k ‾ ∗ 1 T ∗ y k ′ t ( δ k k ′ − y k t ) = − 1 T ∗ ∑ k ′ = 1 ∣ C ∣ N k ‾ ∗ y k t y k ‾ ∗ ( δ k k ′ − y k t ) \frac{\partial L(I,S)}{\partial a_{k}^{t}}=\sum_{k'=1}^{|C|}\frac{\partial L(I,S)}{\partial \overline{y_{k}}}\frac{\partial \overline{y_{k}}}{\partial y_{k'}^{t}}\frac{\partial y_{k'}^{t}}{\partial a_{k}^{t}}=\sum_{k'=1}^{|C|}-\frac{\overline{N_{k}}}{\overline{y_{k}}}*\frac{1}{T}*y_{k'}^{t}(\delta_{kk'}-y_{k}^{t})=-\frac{1}{T}*\sum_{k'=1}^{|C|}\overline{N_{k}}*\frac{y_{k}^{t}}{\overline{y_{k}}}*(\delta_{kk'}-y_{k}^{t}) aktL(I,S)=k=1CykL(I,S)yktykaktykt=k=1CykNkT1ykt(δkkykt)=T1k=1CNkykykt(δkkykt)

交叉熵损失

在上述公式中, N k N_{k} Nk 是常数, ( δ k k ′ − y k t ) (\delta_{kk'}-y_{k}^{t}) (δkkykt) y k t y_{k}^{t} ykt 的线性函数,损失函数主要取决于 y k t y k ‾ \frac{y_{k}^{t}}{\overline{y_{k}}} ykykt,我们希望他尽可能是常数级的。

  • 在初始训练阶段,不同时刻t,不同类别k均匀分布,此时 y k ‾ = y k / T = ∑ t = 1 T y k t / T ≈ y k t , y k t y k ‾ = 1 \overline{y_{k}}=y_{k}/T=\sum_{t=1}^{T}y_{k}^{t}/T \approx y_{k}^{t},\frac{y_{k}^{t}}{\overline{y_{k}}}=1 yk=yk/T=t=1Tykt/Tyktykykt=1
  • 在随后的训练阶段,不同时刻t,某一个类别k’的概率占主要部分,而其他类别非常小,,此时 y k ‾ = y k / T = ∑ t = 1 T y k t / T ≈ y k t / T , y k t y k ‾ = T \overline{y_{k}}=y_{k}/T=\sum_{t=1}^{T}y_{k}^{t}/T \approx y_{k}^{t}/T,\frac{y_{k}^{t}}{\overline{y_{k}}}=T yk=yk/T=t=1Tykt/Tykt/Tykykt=T

可以看到这个值基本上是1~T的常量。

2维预测问题

很多的文本呈二维的分布在图片上,例如一些不规则行文本,弯曲,仿射,多行文本等。这些问题使用传统的方法无法有效解决,在这里我们可以使用ACE损失函数解决。ACE损失函数可以很自然的应用于这些文本识别,因为他并不考虑文本的顺序,而仅仅考虑文本出现的次数或者频率,这在2维图像上也是可以计算的。
假设输出的2维预测图高度H,宽度W(经过CNN,不等于原图大小),第h行第w列的预测输出表示为 y k h w y_{k}^{hw} ykhw ,我们定义
y k ‾ = y k W ∗ H = ∑ w = 1 W ∑ h = 1 H y k h w W ∗ H , N k ‾ = N k H ∗ W \overline{y_{k}}=\frac{y_{k}}{W*H}=\frac{\sum_{w=1}^{W}\sum_{h=1}^{H}y_{k}^{hw}}{W*H} , \overline{N_{k}}=\frac{N_{k}}{H*W} yk=WHyk=WHw=1Wh=1HykhwNk=HWNk

损失函数表示为

L ( I , S ) = − ∑ k = 1 ∣ C ∣ N k ‾ ∗ l n y k ‾ = − ∑ k = 1 ∣ C ∣ N k H ∗ W ∗ l n y k W ∗ H L(I,S)=-\sum_{k=1}^{|C|}\overline{N_{k}}*ln\overline{y_{k}}=-\sum_{k=1}^{|C|}\frac{N_{k}}{H*W}*ln\frac{y_{k}}{W*H} L(I,S)=k=1CNklnyk=k=1CHWNklnWHyk

我们直接将原始的2维预测拉直为1维预测结果,并计算损失。

实验评估

本文在自然场景文本识别,离线手写字符识别,日常场景目标计数三个任务中进行实验评估。我们分别使用1维和2维方法进行预测,得到的预测结果分别为H的特征序列和W*H的特征图。

场景文本识别

本文使用两种类型的文本识别数据集,规则文本如iiit5k,SVT,ICDAR2003,ICDAR2013,不规则文本如ICDAR2015,CUTE80,SVT-Perspective。规则数据集用于研究ACE损失函数1维预测,不规则数据集用于研究2维预测。

实现细节

在规则数据集上的1维文本识别基于网络CRNN,在synth80k的800万合成数据集上训练。

在不规则数据集上的2维文本识别基于网络ResNet-101,conv1被替换为3*3,步长1,conv4_x作为输出,训练数据集来自800万合成数据集和400万张从8万大图中裁剪下来的包含文本的数据集。所有的输入图像都被resize和padding到(96,100)大小,并且输出预测图大小(12,13),相当于8倍解析度下采样。我们然后将(12,13)的2维预测图拉直为12x13的一维预测序列,并使用ACE损失函数。

实验结果

回归损失与交叉熵损失

我们对规则文本进行1维预测,分别使用ACE回归损失与交叉熵损失。
回归损失存在梯度消失的问题,前面的一些层参数无法训练到,虽然回归损失能够收敛,但是收敛最终的单词错误率与字符错误率都非常高;交叉熵损失能够最终收敛到一个非常高的水平。与原始的CRNN网络相比较,表现有略微的提升。

不规则文本

我们在不规则文本上使用2维预测,仅仅采用ResNet-101的CNN网络,没有加入LSTM等序列信息。最终我们发现这个模型在CUTE和ICDAR15数据集上有非常好的效果,尤其是CUTE,这个数据集的图像都拥有高解析度,弯曲严重,严重的不规则文本,非常适合ACE的2维预测。网络模型在没有字典的SVTP数据集上效果一般,因为这个数据集图像解析度相当低,仅仅使用CNN网络而不采用LSTM很难提取解析度如此低的文本特征。

我们可视化最终的12x13大小的预测图,能够发现,在2维空间中预测结果字符与原始图像中文本有着非常相似的分布。

你可能感兴趣的:(神经网络模型)