CTC 讲解

Connectionist Temporal Classification


https://sunnycat2013.gitbooks.io/blogs/content/posts/ctc/learning-ctc.html


因为最近做了一些用连续标签做文字识别标签任务的工作,对 ctc 有了一些了解,在此记录一下。

在学习 CTC 的时候,也看了不少博客,但是我觉得讲的最好的还是原论文 Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks 解释的最清楚。
对于没接触过这个概念的人,可能加一些例子会更好理解一些。
我就来加一些例子。

背景知识

用实例来说,我们在做 ocr 工作时,我们希望给一行文字的图片让机器识别出来这个图片里面的文字。
语音识别任务中,给了一段语音片段,我们希望能把这段语音识别成可编辑的文字。
但是,在对每个片段进行分类模型训练之前,需要对每个训练样本进行切割标注。
这是项非常繁琐的工作非常不利于模型的训练。

下面是个对文字进行标注的工具,大家可以看一下。如果我们在做文字识别工作时,对每个文字都要明确标出这个字在图片中的位置、高、宽,这将会是一个多么巨大的工作量。

CTC 讲解_第1张图片

RNNs 在序列学习任务中有着优越的性能,但是它也有一些缺点。
如,在上面描述的那种对输入模型的数据需要预处理的缺点。同时,对 RNNs 的序列也需要一定的整合才能得到最终的预测序列。
而 CTC 解决了对输入序列的单个词的切分和对输入序列的整合工作。

RNN 的输出

输入序列的切分与标注上面已经举了一个例子,现在举一个输出序列整合的例子。
我们现在有一个图片的输入 hello
假设这个图片中每个红都作为 RNN 的一步输入,那么(如果这个模型训练的还不错的话)它的输出序列应该是 hheelloo
但是,我们知道 RNN 每一步的输出其实都是一组概率分布,p(l|x), l \in Alphebatp(lx),lAlphebat
如,对第一个矩形框的输出概率可能是 p(l = 'h' | x) = 0.5, p(l = 'm' | x) = 0.3 \cdotsp(l=hx)=0.5,p(l=mx)=0.3

时序分类(Temporal Classification)

先给几个定义。

符号 解释
LL 字母表 Alphebat
(R^m)^\ast(Rm) mm 表示输入数据的一个“宽度”,如我们输入的是一个图片时,宽度可以是一个定值。 \ast 表示这个串的长度不定 \in [0, +\infty)[0,+),如输入图片的长度是未知长度。
\chiχ \chi = (R^m)^\astχ=(Rm) 表示输入数据空间。
ZZ Z = L^\astZ=L 表示由字母表排列而成的标签集合,我们可以理解成单词表。
D_{\chi \times Z}Dχ×Z 真实数据空间。
z z = (z_1, z_2, \cdots , z_U)=(z1,z2,,zU) 是 ZZ 中的一个样本。
x x = (x_1, x_2, \cdots , x_T), U \leq T=(x1,x2,,xT),UT,表示一个样本输入数据的输入序列。如一个定高图片每一列像素可以认为是一个 x_ixi
SS S \subset D_{\chi \times Z}SDχ×Z 训练样本集,这个集合中的每个样本都是一个 (x, z) 组合
S'S S' \subset D_{\chi \times Z}, S' \bigcap S = \emptysetSDχ×Z,SS=
hh 时序分类器。

由上面的定义,我们可以看出,因为输入和输出和长度未必相等,所以没有办法事先把这两种数据对齐。

目标

时序分类的目标就是学习h: \chi \longmapsto Zh:χZ

损失函数

用于 CTC 的损失函数是 Lebal Error Rate(LER)。 这里我们需要知道“最小编辑距离(Edit Distance, ED)”这个概念,在 CTC 的损失函数就用到了。

在学习算法的时候,ED 算是一个比较经典的动态规划问题,但在实际工作中其实很少用到这类算法。 所以第一次知道这个算法能用在这里,我还是挺开心的。

损失函数定义如下:LER(h, S') = \frac{1}{|S'|}\sum_{(x,z) \in S'}\frac{ED(h(x), z)}{|z|}LER(h,S)=S1(x,z)SzED(h(x),z)用编辑距离(ED)来衡量文字串的预测情况还是一件蛮符合直观理解的事情。

连接的时序分类(Connectionist Temporal Classification)

写了半天终于到正题了,下面开始讲 CTC! CTC 网络的 softmax 输出层输出的类别有 |L| + 1L+1 种,因为有一个分隔符,比如说是空格。

这个分隔符其实蛮重要的,它可以很好地区分一个输出序列串中,哪些子串是属于同一个文字的图片区域的输出结果。

一个输出序列的概率

首先,我们来看一个实例:输入 x 的长度是 T,每一个帧的维度是 m。模型的输出的长度也是 T,每一帧的维度是 n。其中 m n 可以相同也可以不同。用数学定义我们的这个模型就是:

y = N_{w}(x), N_w: (R^m)^T \longmapsto (R^n)^Ty=Nw(x),Nw:(Rm)T(Rn)T这里,我们引入几个新的概念:

  • y_{k}^tykt 表示,在时间帧为 t 的时候,模型的第 k 个输出值。 我们可以理解 y_{k}^tykt 为,模型认为这次输入的 x_txt 被认为是字母表 L'L 中第 k 个字母的概率。

  • \piπ 表示一个输出序列的组合,如 y_2^1 y_{20}^2 y_1^3 y_5^4 \cdotsy21y202y13y54 那么每组输入,对应的输出都有 (R^n)^T(Rn)T 种可能的字母排列,我们用 \piπ 表示其中一种排列。论文中称这种排列为 path

  • p(\pi | x)p(πx) 一种输出组合的概率。公式如下:p(\pi | x) = \prod y_{\pi_t}^t, \forall_{t = 1}^{T}\pi \in {L'}^T.p(πx)=yπtt,t=1TπLT.

如,hello 的一种可能的输出 hheelloo 的概率可以表示为 p('hheelloo' | hello.png) = y_8^1\ast y_8^2 \ast y_{5}^3 \ast y_{5}^4 \ast y_{12}^5 \ast y_{12}^6 \ast y_{15}^7 \ast y_{15}^8p(hheelloohello.png)=y81y82y53y54y125y126y157y158

一种输出序列的规整

同一种输入对应的多种输出可能会有多种形式。 如 hello 的输出可能是 hheel-loo 也可能是 hh-ee-l-l-oo 等。这里的 - 表示空格。 原论文处理这种情况的规则非常简单 We do this by simply removing all blanks and repeated labels from the pathsB(a-ab-) = B(-aa--abb) = aabB(aab)=B(aaabb)=aab一句话:把空格和连续重复的字母去掉。 那么B(hheel-loo) = helloB(hheelloo)=hello

B(hh-ee-l-l-oo) = helloB(hheelloo)=hello

则模型预测标签为:l = B(\pi), |l| \leq Tl=B(π),lT

预测标签的概率

我们可以看到预测标签有很多备选的输出序列,所以预测标签 ll 的概率公式:p(l|x) = \sum_{\pi \in B^{-1}(l)}p(\pi|x).p(lx)=πB1(l)p(πx).其中,B^{-1}(l)B1(l) 是输出序列规整函数 B(\pi)B(π) 的反函数。 如,p(l = 'hello'| x) = p(\pi = 'hheel-loo'|x) + p(\pi = 'hh-ee-l-l-oo' | x) + \cdotsp(l=hellox)=p(π=hheelloox)+p(π=hheelloox)+

讲到这里,大家就应该明白 CTC 是怎么工作的了。当然还有很多为了实现而做的工作,有时间再接着写吧。


ctc evaluation 中有关 loss 的含义

ctc_decode

参考

tensorflow ctc_ops

你可能感兴趣的:(机器学习)