The CTC Loss Function and Its Implementations [1]

    • Introduction
    • Principle
    • Implementations
    • References

This article explains the main principle behind the CTC loss function and surveys its current open-source implementations.

Introduction

Many real-world sequence learning tasks require predicting label sequences from noisy, unsegmented input data. In speech recognition, for example, an acoustic signal must be transcribed into words. RNNs appear to be powerful sequence learners well suited to such tasks, but their applicability has been limited by two requirements: the training data must be pre-segmented, and the RNN outputs must be post-processed into label sequences. The method proposed in [1] addresses both problems at once by training RNNs to label unsegmented data directly.

Principle

Core idea (quoted from [1]):

The basic idea is to interpret the network outputs as a probability distribution over all possible label sequences, conditioned on a given input sequence. Given this distribution, an objective function can be derived that directly maximises the probabilities of the correct labellings. Since the objective function is differentiable, the network can then be trained with standard backpropagation through time (Werbos, 1990).
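To make this objective concrete: the probability p(l|x) that the objective maximises can be computed with the standard CTC forward (alpha) recursion over an extended label sequence in which blanks are interleaved between labels. Below is a minimal pure-Python sketch of that recursion in log space; the function and variable names are my own, not from the paper or any library.

```python
import math

NEG_INF = float("-inf")

def logsumexp(*xs):
    """Numerically stable log(sum(exp(x))) over the given values."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_log_likelihood(log_probs, labels, blank=0):
    """CTC forward (alpha) recursion.

    log_probs: per-timestep lists of log-probabilities over the alphabet
               (i.e. log-softmax outputs of the network), shape T x A.
    labels:    target label sequence, without blanks.
    Returns log p(labels | input)."""
    # Extended sequence: a blank before, between, and after every label.
    ext = [blank]
    for c in labels:
        ext += [c, blank]
    S, T = len(ext), len(log_probs)

    # Initialisation: paths may start with a blank or the first label.
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, T):
        new = [NEG_INF] * S
        for s in range(S):
            a = alpha[s]                       # stay on the same symbol
            if s >= 1:
                a = logsumexp(a, alpha[s - 1])  # advance by one
            # Skip over a blank, allowed only between distinct labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a = logsumexp(a, alpha[s - 2])
            new[s] = a + log_probs[t][ext[s]]
        alpha = new

    # Valid paths end on the last label or the trailing blank.
    return logsumexp(alpha[S - 1], alpha[S - 2]) if S > 1 else alpha[S - 1]
```

As a sanity check: with T = 2 uniform timesteps over the alphabet {blank, 1} and target [1], the three paths [1,1], [blank,1], and [1,blank] all map to the target, giving p = 3 × 0.25 = 0.75.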

Terminology (quoted from [1]):

In what follows, we refer to the task of labelling unsegmented data sequences as temporal classification (Kadous, 2002), and to our use of RNNs for this purpose as connectionist temporal classification (CTC). By contrast, we refer to the independent labelling of each time-step, or frame, of the input sequence as framewise classification.

Temporal Classification
[Figure 1]

Label Error Rate
[Figure 2]

Connectionist Temporal Classification
[Figure 3]

From Network Outputs to Labellings
[Figures 4 and 5]
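The mapping from network output paths to labellings described in the paper is the many-to-one map B: first collapse repeated symbols, then remove blanks. A short sketch (function name is my own):

```python
def collapse(path, blank=0):
    """The CTC map B: merge consecutive repeats, then drop blanks.

    Example: blank=0, path [0,1,1,0,1,2,2] -> labelling [1,1,2].
    Note the blank between the two 1s keeps them as separate labels."""
    out = []
    prev = None
    for s in path:
        if s != prev and s != blank:
            out.append(s)
        prev = s
    return out
```

This is why CTC needs the blank symbol: without it, a labelling containing the same label twice in a row could never be produced.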

Implementations

Because this loss function is so widely applicable, many organizations have implemented it. The main open-source implementations are listed below.
1. Baidu: https://github.com/baidu-research/warp-ctc
2. PyTorch 1.0: https://pytorch.org/docs/master/nn.html#torch.nn.CTCLoss
3. SeanNaren: https://github.com/SeanNaren/warp-ctc

Whichever implementation you use, the calling convention is much the same, and the parameters are identical:

where T represents the number of timesteps in the input to CTC, L represents the length of the labels for each example, and A represents the alphabet size.
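Using the PyTorch implementation as an example, the T/L/A convention above maps onto `torch.nn.CTCLoss` roughly as follows. The shapes and names below follow the PyTorch documentation; the concrete sizes are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# T: timesteps, N: batch size, C: alphabet size (incl. blank), S: max label length
T, N, C, S = 50, 16, 20, 15

# Network outputs must be per-timestep log-probabilities, shape (T, N, C).
log_probs = torch.randn(T, N, C).log_softmax(2).requires_grad_()

# Targets use labels 1..C-1; index 0 is reserved for the blank by default.
targets = torch.randint(1, C, (N, S), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(5, S + 1, (N,), dtype=torch.long)

ctc = nn.CTCLoss(blank=0)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
loss.backward()  # gradients flow back to log_probs as usual
```

Note that `input_lengths` must be large enough to emit each target (at least the target length plus one blank between every pair of repeated labels), otherwise the loss for that example is infinite unless `zero_infinity=True` is set.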

References

[1] A. Graves et al. Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks. ICML 2006. https://www.cs.toronto.edu/~graves/icml_2006.pdf
