CTC原理介绍

    在做OCR时用到了CTC Loss,对CTC Loss一直都是只有宏观的概念,并没有认真研究它的细节原理(主要是没勇气研究),最近由于需要修改CTC中的解码部分,所以又硬着头皮看论文,查资料,经过不懈努力,总算是明白了一点。接下来我将按照我的理解对CTC的损失计算、解码进行详细说明,限于本人水平有限,不对之处,敬请指正。

CTC出现的背景

    在序列学习任务中,RNN对训练样本一般有这样的依赖条件:输入序列和输出序列之间的映射关系已经事先标注好了,可以根据输出序列和标注样本间的差异来直接定义RNN模型的Loss函数。比如,在词性标注任务中,训练样本中每个词(或短语)对应的词性会事先标注好。
    但是,在OCR、语音识别时,由于我们很难对样本的输入进行标注(很难区分相邻信息间的分界线),所以仅使用RNN是很难解决这些问题的。这时Alex Graves等人在ICML 2006上提出的一种端到端的RNN训练方法Connectionist Temporal Classification(CTC),它可以让RNN直接对序列数据进行学习,而无需事先标注好训练数据中输入序列和输入序列的映射关系,使得RNN模型在语音识别等序列学习任务中取得更好的效果,在语音识别和图像识别等领域CTC算法都有很比较广泛的应用。
CTC原理介绍_第1张图片

CTC介绍

    假设RNN(一般都是经过了 s o f t m a x softmax softmax的)的某一条输出为 π = { π 1 , π 2 , . . π n } \pi=\{\pi_1,\pi_2,..\pi_n\} π={π1,π2,..πn},对应的标签为 I I I,m <= n,CTC的目的就是为了将 π \pi π通过一个函数B映射成 I I I,即: I = B ( π ) I=B(\pi) I=B(π) y π t t y_{\pi_t}^{t} yπtt表示在 t t t时刻输出为 π t \pi_t πt的概率。
    假设每个输出之间都是相互独立的,则其中一条符合 l = B ( π ) l=B(\pi) l=B(π)的路径的概率为: p ( π ∣ x ) = ∏ t = 1 T y π t t p(\pi|x)=\prod_{t=1}^{T}y_{\pi_t}^{t} p(πx)=t=1Tyπtt
    够映射成 I I I的概率为: p ( I ∣ x ) = ∑ π ∈ B − 1 ( I ) p ( π ∣ x ) p(I|x)=\sum_{\pi\in B^{-1}(I)}p(\pi|x) p(Ix)=πB1(I)p(πx),这里的 π ∈ B − 1 ( I ) \pi\in B^{-1}(I) πB1(I)是指所有能够映射成 l l l π \pi π
CTC原理介绍_第2张图片
    由于直接暴力计算 p ( I ∣ x ) p(I|x) p(Ix)的复杂度非常高,作者借鉴HMM的Forward-Backward算法思路,利用动态规划算法求解。
    在正式介绍前向,后向算法之前,我们先说明一些条件,方便后续的理解。

  1. 将目标序列 I I I转化为label,在目标序列的首尾和中间都加上空格,用 l ′ l^{'} l表示。如上图所示:我们的目标序列是:CAT,将CAT的首尾和中间都添加空格(blank),变成了:-C-A-T-,图中白色代表实体,黑色代表空格。
  2. 路径的搜索只能从左上方往右下方进行,不能低于当前位置
  3. 相同字符之间至少需要一个空格。比如:序列aa之间至少有一个“-”,否则就是错误的,因为不包含"-"的会被合并成一个a
  4. 非空字符不能被跳过。搜索过程中非空字符必须要对应一个输出
  5. 起点必须从第一个(空白)或第二个(第一个非空字符)开始,终点必须在最后一个(空白)或第二个(最后一个非空字符)结束。
前向算法

    这里用 α ( t , u ) = ∑ π ∈ V ( s , u ) ∏ i = 1 t y π i i \alpha(t,u)=\sum_{\pi\in V(s,u)}\prod_{i=1}^{t} y_{\pi_i}^{i} α(t,u)=πV(s,u)i=1tyπii表示 t t t时刻经过节点 u u u的路径的概率总和( u u u l ′ l^{'} l的索引,从1开始),特别的当 t = 1 t=1 t=1时:
α ( 1 , 1 ) = y b 1 α ( 1 , 2 ) = y l 1 α ( 1 , u ) = 0 ,    u > 2 \begin{aligned} & \alpha(1, 1)=y_b^1 \\ & \alpha(1, 2)=y_{l_1} \\ & \alpha(1, u)=0,\space\space u\gt2 \end{aligned} α(1,1)=yb1α(1,2)=yl1α(1,u)=0,  u>2
    其他时刻需要分情况考虑:

  1. t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为空白时,那么能够到达它的节点为 ( u , t − 1 ) (u, t-1) (u,t1) ( u − 1 , t − 1 ) (u-1, t-1) (u1,t1),可以表达为: α ( t , u ) = ( α ( t − 1 , u ) + α ( t − 1 , u − 1 ) ) ∗ y u t \alpha(t,u)=(\alpha(t-1,u)+\alpha(t-1,u-1))*y_{u}^t α(t,u)=(α(t1,u)+α(t1,u1))yut
  2. t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为非空字符且与前一个非空字符相同时,那么能够到达它的节点为 ( u , t − 1 ) (u, t-1) (u,t1) ( u − 1 , t − 1 ) (u-1, t-1) (u1,t1),可以表达为: α ( t , u ) = ( α ( t − 1 , u ) + α ( t − 1 , u − 1 ) ) ∗ y u t \alpha(t,u)=(\alpha(t-1,u)+\alpha(t-1,u-1))*y_{u}^t α(t,u)=(α(t1,u)+α(t1,u1))yut,与1一样;
  3. t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为非空字符且与前一个非空字符不相同时,那么能够到达它的节点为 ( u , t − 1 ) (u, t-1) (u,t1) ( u − 1 , t − 1 ) (u-1, t-1) (u1,t1) ( u − 2 , t − 1 ) (u-2,t-1) (u2,t1),可以表达为: α ( t , u ) = ( α ( t − 1 , u ) + α ( t − 1 , u − 1 ) + α ( t − 1 , u − 2 ) ) ∗ y u t \alpha(t,u)=(\alpha(t-1,u)+\alpha(t-1,u-1)+\alpha(t-1,u-2))*y_{u}^t α(t,u)=(α(t1,u)+α(t1,u1)+α(t1,u2))yut
        论文中用一下表述该概括这三种情况:
    α ( t , u ) = y u t ∑ i = f ( u ) u α ( t − 1 , i ) \begin{aligned} & \alpha(t,u)=y_u^t \sum_{i=f(u)}^{u}\alpha(t-1,i) \\ \end{aligned} α(t,u)=yuti=f(u)uα(t1,i)
    其中: f ( u ) = { u − 1 l ′ [ u ] = b l a n k   o r   l ′ [ u ] = l ′ [ u − 2 ] u − 2 otherwise f(u)= \begin{cases} u-1& \text{$l^{'}[u]=blank \space or \space l^{'}[u]=l^{'}[u-2]$}\\ u-2& \text{otherwise} \end{cases} f(u)={u1u2l[u]=blank or l[u]=l[u2]otherwise
        最后,总的损失(考虑最后一个是空格和非空两种情况, , ∣ l ′ ∣ ,|l^{'}| ,l表示label的长度)可以表示为:
    L ( S ) = − l n ( I ∣ x ) = − l n ( α ( T , ∣ l ′ ∣ ) + α ( T , ∣ l ′ ∣ − 1 ) ) L(S)=-ln(I|x)=-ln(\alpha(T,|l^{'}|)+\alpha(T,|l^{'}|-1)) L(S)=ln(Ix)=ln(α(T,l)+α(T,l1))
后向算法

    后向算法与前向算法一样,只是方向是反的。前向算法是从 t = 1 t=1 t=1 t = T t=T t=T,后向算法是 t = T t=T t=T t = 1 t=1 t=1
    这里用 β ( t , u ) = ∑ π ∈ V ( s , u ) ∏ i = t + 1 T y π i i \beta(t,u)=\sum_{\pi\in V(s,u)}\prod_{i=t+1}^{T} y_{\pi_i}^{i} β(t,u)=πV(s,u)i=t+1Tyπii表示 t t t时刻经过节点 u u u的路径的概率总和( u u u l ′ l^{'} l的索引,从1开始),特别的当 t = T t=T t=T时:
β ( T , ∣ l ′ ∣ ) = 1 β ( T , ∣ l ′ ∣ − 1 ) = 1 β ( T , u ) = 0 ,    u < ∣ l ′ ∣ − 2 \begin{aligned} & \beta(T, |l^{'}|)=1 \\ & \beta(T, |l^{'}|-1)=1 \\ & \beta(T, u)=0,\space\space u\lt|l^{'}|-2 \end{aligned} β(T,l)=1β(T,l1)=1β(T,u)=0,  u<l2
    其他时刻需要分情况考虑:

  1. t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为空白时,那么能够到达它的节点为 ( u , t + 1 ) (u, t+1) (u,t+1) ( u + 1 , t + 1 ) (u+1, t+1) (u+1,t+1),可以表达为: β ( t , u ) = ( β ( t + 1 , u ) + β ( t + 1 , u + 1 ) ) ∗ y u t + 1 \beta(t,u)=(\beta(t+1,u)+\beta(t+1,u+1))*y_{u}^{t+1} β(t,u)=(β(t+1,u)+β(t+1,u+1))yut+1
  2. t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为非空字符且与前一个非空字符相同时,那么能够到达它的节点为 ( u , t + 1 ) (u, t+1) (u,t+1) ( u + 1 , t + 1 ) (u+1, t+1) (u+1,t+1),可以表达为: β ( t , u ) = ( β ( t + 1 , u ) + β ( t + 1 , u + 1 ) ) ∗ y u t + 1 \beta(t,u)=(\beta(t+1,u)+\beta(t+1,u+1))*y_{u}^{t+1} β(t,u)=(β(t+1,u)+β(t+1,u+1))yut+1,与1一样;
  3. t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为非空字符且与前一个非空字符不相同时,那么能够到达它的节点为 ( u , t + 1 ) (u, t+1) (u,t+1) ( u + 1 , t + 1 ) (u+1, t+1) (u+1,t+1) ( u + 2 , t + 1 ) (u+2,t+1) (u+2,t+1),可以表达为: β ( t , u ) = ( β ( t + 1 , u ) + β ( t + 1 , u + 1 ) + β ( t + 1 , u + 2 ) ) ∗ y u t + 1 \beta(t,u)=(\beta(t+1,u)+\beta(t+1,u+1)+\beta(t+1,u+2))*y_{u}^{t+1} β(t,u)=(β(t+1,u)+β(t+1,u+1)+β(t+1,u+2))yut+1
        论文中用一下表述该概括这三种情况:
    β ( t , u ) = y u t + 1 ∑ i = f ( u ) u β ( t + 1 , i ) \begin{aligned} & \beta(t,u)=y_{u}^{t+1} \sum_{i=f(u)}^{u}\beta(t+1,i) \\ \end{aligned} β(t,u)=yut+1i=f(u)uβ(t+1,i)
    其中: f ( u ) = { u + 1 l ′ [ u ] = b l a n k   o r   l ′ [ u ] = l ′ [ u + 2 ] u + 2 otherwise f(u)= \begin{cases} u+1& \text{$l^{'}[u]=blank \space or \space l^{'}[u]=l^{'}[u+2]$}\\ u+2& \text{otherwise} \end{cases} f(u)={u+1u+2l[u]=blank or l[u]=l[u+2]otherwise
        最后,总的损失(考虑最后一个是空格和非空两种情况, , ∣ l ′ ∣ ,|l^{'}| ,l表示label的长度)可以表示为:
    L ( S ) = − l n ( I ∣ x ) = − l n ( β ( 1 , 1 ) + β ( 1 , 2 ) ) L(S)=-ln(I|x)=-ln(\beta(1,1)+\beta(1,2)) L(S)=ln(Ix)=ln(β(1,1)+β(1,2))
损失函数

    这里我们可以利用前向算法和后向算法来表示 t t t时刻通过节点 u u u的概率:
α ( t , u ) β ( t , u ) = ∑ π ∈ X ( t , u ) ∏ t = 1 T y π t t = ∑ π ∈ X ( t , u ) p ( π ∣ x ) \begin{aligned} \alpha(t,u)\beta(t,u) & =\sum_{\pi \in X(t, u)}\prod_{t=1}^{T}y_{\pi_t}^{t} \\ & =\sum_{\pi \in X(t, u)}p(\pi|x) \\ \end{aligned} α(t,u)β(t,u)=πX(t,u)t=1Tyπtt=πX(t,u)p(πx)
    之前我们只是表示了总的损失,那么 t t t时刻的损失如何表示了,论文中做了表述, t t t时刻的损失可以表示为:
KaTeX parse error: Undefined control sequence: \inX at position 58: …=-\ln(\sum_{\pi\̲i̲n̲X̲(t,u)}\prod_{t=…

梯度反向传播

CTC原理介绍_第3张图片
    如上图所示,我们要求损失关于 u k t u_k^t ukt的梯度。在求解之前,我们需要先做一些准备工作(用 z z z表示 I I I,论文是这样表达的):
∂ L ( z , x ) ∂ y k ′ t = ∂ ( − ln ⁡ p ( z ∣ x ) ) ∂ y k ′ t = − 1 p ( z ∣ x ) ∂ ∑ π ∈ B − 1 ( z ) p ( π ∣ x ) ∂ y k ′ t = − 1 p ( z ∣ x ) ( ∑ u ∈ B ( z , k ′ ) ∂ α ( t , u ) β ( t , u ) ∂ y k ′ t ) = − 1 p ( z ∣ x ) ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) y k ′ t \begin{aligned} \frac{\partial L(z,x)}{\partial y_{k^{'}}^t} & = \frac{\partial(-\ln p(z|x))}{\partial y_{k^{'}}^t} \\ & =-\frac{1}{p(z|x)} \frac{\partial \sum_{\pi \in B^{-1}(z)}p(\pi|x)}{\partial y_{k^{'}}^t} \\ & =-\frac{1}{p(z|x)} (\sum_{u \in B(z,k^{'})}\frac{\partial \alpha(t,u)\beta(t,u)}{\partial y_{k^{'}}^{t}}) \\ & =-\frac{1}{p(z|x)} \sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}} \\ \end{aligned} yktL(z,x)=ykt(lnp(zx))=p(zx)1yktπB1(z)p(πx)=p(zx)1(uB(z,k)yktα(t,u)β(t,u))=p(zx)1uB(z,k)yktα(t,u)β(t,u)
∂ y k ′ t ∂ u k t = y k ′ t ( δ k k ′ − y k t ) \begin{aligned} \frac{\partial y_{k^{'}}^t}{\partial u_k^t}=y_{k^{'}}^t(\delta_{kk^{'}}-y_k^t) \end{aligned} uktykt=ykt(δkkykt)
这个是softmax的求导,具体过程这里就不累赘了。
其中: δ k k ′ = { 1 k = k ′ 0 otherwise \delta_{kk^{'}}= \begin{cases} 1& \text{$k=k^{'}$}\\ 0& \text{otherwise} \end{cases} δkk={10k=kotherwise
    如上图所示,如果我们要求损失关于 u k t u_k^t ukt的梯度,则:
∂ L ( z , x ) ∂ u k t = ∑ k ′ ∂ L ( z , x ) ∂ y k ′ t ∂ y k ′ t ∂ u k t = − ∑ k ′ 1 p ( z ∣ x ) ( ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) y k ′ t ) y k ′ t ( δ k k ′ − y k t ) = − 1 p ( z ∣ x ) ∑ k ′ ( ( ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) y k ′ t ) y k ′ t ( δ k k ′ − y k t ) ) = − 1 p ( z ∣ x ) ( ∑ k ′ = k ( ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) y k ′ t ) y k ′ t ( 1 − y k t ) + ∑ k ′ ! = k ( ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) y k ′ t ) y k ′ t ( 0 − y k t ) ) = − 1 p ( z ∣ x ) ( ∑ k ′ = k ( ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) ) ( 1 − y k t ) − ∑ k ′ ! = k ( ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) ) y k t ) = − 1 p ( z ∣ x ) ( ∑ k ′ = k ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) − ∑ k ′ = k ( ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) ) y k t ) − ∑ k ′ ! = k ( ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) ) y k t ) = − 1 p ( z ∣ x ) ( ∑ u ∈ B ( z , k ) α ( t , u ) β ( t , u ) − ∑ k ′ ( ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) ) y k t ) = 1 p ( z ∣ x ) ( − ∑ u ∈ B ( z , k ) α ( t , u ) β ( t , u ) + y k t ∑ u ∈ B ( z , k ′ ) α ( t , u ) β ( t , u ) ) = 1 p ( z ∣ x ) ( − ∑ u ∈ B ( z , k ) α ( t , u ) β ( t , u ) + y k t p ( z ∣ x ) ) = y k t − 1 p ( z ∣ x ) ∑ u ∈ B ( z , k ) α ( t , u ) β ( t , u ) \begin{aligned} \frac{\partial L(z,x)}{\partial u_k^t} & = \sum_{k^{'}} \frac{\partial L(z,x)}{\partial y_{k^{'}}^t}\frac{\partial y_{k^{'}}^t}{\partial u_k^t} \\ & = -\sum_{k^{'}}\frac{1}{p(z|x)} (\sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}}) y_{k^{'}}^t(\delta_{kk^{'}}-y_k^t) \\ & = -\frac{1}{p(z|x)}\sum_{k^{'}}((\sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}}) y_{k^{'}}^t(\delta_{kk^{'}}-y_k^t)) \\ & = -\frac{1}{p(z|x)}(\sum_{k^{'}=k}(\sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}}) y_{k^{'}}^t(1-y_k^t) +\sum_{k^{'}!=k}(\sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}}) y_{k^{'}}^t(0-y_k^t)) \\ & = -\frac{1}{p(z|x)}(\sum_{k^{'}=k}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))(1-y_k^t) -\sum_{k^{'}!=k}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))y_k^t) \\ & = -\frac{1}{p(z|x)}(\sum_{k^{'}=k} \sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u) - \sum_{k^{'}=k}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))y_k^t) - \sum_{k^{'}!=k}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))y_k^t) \\ & = -\frac{1}{p(z|x)}( \sum_{u \in B(z,k)} \alpha(t,u)\beta(t,u) - \sum_{k^{'}}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))y_k^t) \\ & = \frac{1}{p(z|x)}(-\sum_{u \in B(z,k)} \alpha(t,u)\beta(t,u) + y_k^t \sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u)) \\ & = \frac{1}{p(z|x)}(-\sum_{u \in B(z,k)} \alpha(t,u)\beta(t,u) + y_k^t p(z|x)) \\ & = y_k^t - \frac{1}{p(z|x)}\sum_{u \in B(z,k)} \alpha(t,u)\beta(t,u) \\ \end{aligned} uktL(z,x)=kyktL(z,x)uktykt=kp(zx)1(uB(z,k)yktα(t,u)β(t,u))ykt(δkkykt)=p(zx)1k((uB(z,k)yktα(t,u)β(t,u))ykt(δkkykt))=p(zx)1(k=k(uB(z,k)yktα(t,u)β(t,u))ykt(1ykt)+k!=k(uB(z,k)yktα(t,u)β(t,u))ykt(0ykt))=p(zx)1(k=k(uB(z,k)α(t,u)β(t,u))(1ykt)k!=k(uB(z,k)α(t,u)β(t,u))ykt)=p(zx)1(k=kuB(z,k)α(t,u)β(t,u)k=k(uB(z,k)α(t,u)β(t,u))ykt)k!=k(uB(z,k)α(t,u)β(t,u))ykt)=p(zx)1(uB(z,k)α(t,u)β(t,u)k(uB(z,k)α(t,u)β(t,u))ykt)=p(zx)1(uB(z,k)α(t,u)β(t,u)+yktuB(z,k)α(t,u)β(t,u))=p(zx)1(uB(z,k)α(t,u)β(t,u)+yktp(zx))=yktp(zx)1uB(z,k)α(t,u)β(t,u)
    至此,CTC的反向传播推导完成。CTC的解码过程后续会给出,并且附带原生Python代码实现,近期推出!

你可能感兴趣的:(算法原理,深度学习,OCR,CTC原理,CTC推导)