CTCloss 详解
简介
在ocr任务与机器翻译中,输入与输出GT文本很难在单词上对齐,在预处理的时候对齐是非常困难的,但是如果不对齐而直接训练模型的话,由于字符距离的不同,导致模型很难收敛.
CTC(Connectionist Temporal Classification)避免了输入与输出手动对齐,适合OCR或语音这样的序列应用;
建模
给定输入序列,以及对应的标签数据.不一定相等 我们的工作是找到一个X到Y的映射.这种对时序数据进行分类算法叫做Temporal Classification。
对比传统分类方法,时序分类有一下困难:
- X和Y的长度都是变化且不相等的.
- 对于一个端到端的模型,我们并不想手动设计X和Y之间的对齐.
CTC提供了解决方案,对于一个给定的输入序列 ,CTC给出所有可能的 Y 的输出分布。根据这个分布,我们可以输出最可能的结果或者给出某个输出的概率。
loss:给定输入序列X,我们希望最大化Y的后验概率 ,应该是可导的,这样我们就能执行梯度下降算法进行优化.
infer:
1.1 对齐
在ocr任务中,输入X是一张含有"CAT"的图片,输出Y是文本[C,A,T]
最原始的对齐方式将X分割成若干个时间片,每一个时间片得到一个字符的输出,然后合并连续重复出现的字符.
然而这样做有两个缺点:
- 几乎不可能将 X 的每个时间片都和输出Y对应上,例如OCR中字符的间隔,语音识别中的停顿;
- 不能处理有连续重复字符出现的情况,例如单词“APPLE”,按照上面的算法,输出的是“APLE”而非“APPLE”。
为了解决上面的问题,CTC引入了空白字符,
CTC的对齐涉及去除重复字母和去除 空白字符 两部分.
其规则:
- 连续相同的字符做去重
- 去重空白字符
比如,对于长度为10的输入序列,以下RNN输出序列都可以映射为: apple
- _aappp_ple
- ap_p_|_ e
- _ _ app_ple_
最后要计算P(Y|X),可以累加其对应的全部输出(全部输出为apple)的路径概率之和.
因此,在训练阶段,我们要对GT进行标签扩充.其做法是:
头尾加空白符,并在GT中的每一个字符间插入空白符;
用l表示最终标签,l’表示扩展后的形式,
则由2|l| + 1 = 2|l’|,比如:l=apple => l’=_a_p_p_l_e_
如图所示:
1.2 路径搜索与动态规划
上图的路径搜索中:
定义:
- 为第t个时刻,gt字符串的第s个字符的路径前向概率.
- 为预测矩阵中第t时刻是第s个字符的概率.
- 为输入序列x,输出为l的概率,我们要最大化其概率
(1) 如果为空白符.则只能由前一个空白符或者其GT中该字符为上一时刻得到(因为我们是隔一个字符插入空白符,当前字符是空白的话,如果前一个也不是s的字符,就会错过GT中s字符,导致最后的path没法解析到GT,所以要么最多连续两个空白格,要么是前一个已经出现s字符,当前可以为空白)
所以这种情况下其概率为:
(2)如果不为空白符,那么该点的前向概率之和可以通过以下路径得到
- :s为当前gt的第s个字符,然后前一个为空白字符的概率
- :当前字符s连续出现的概率
- :前一个字符GT的上一个字符,当前步是GT的S个字符,代表s-1,s之间没有空白符,没有连续的字符的概率(因为每个字符都隔了一个空白符)
所以这种情况下前向概率为:
初始化值:
(代表了T时刻可以从空白符出发,也可以从gt的第0个字符开始)
最后我们需要计算
两个前向概率之和便得到前向概率之和.如图的右下角两个位置概率之和.
利用前向概率计算ctc 的loss
即 等于最小化对数域.
所以loss的值为:
简化计算
我们看到在计算过程中我们发现了大量的连乘。由于每一个数字都是浮点数,那么这样连乘下去,最终数字有可能非常小而导致underflow。所以我们要将这个计算过程转到对数域上。这样我们就将其中的乘法转变成了加法。
由于最后计算loss为-ln(P)
所以在计算前向概率的时候可以直接计算log(p)的值..
logsumexp的优化
如果我们有N个概率,,我们想求其对数域之和:
如果很大或很小,朴素的直接计算会上溢出或下溢出,从而导致严重问题。举个例子,对于[0,1,0]直接计算是可行的,我们可以得到1.55。但对于[1000,1001,1000]却并不可行,我们会得到inf;对于[-1000,-999,-1000],还是不行,我们会得到-inf.
解决方法:
一般情况下,a取N个值中的最大值;
这可以保证指数最大不会超过0,于是你就不会上溢出。即便剩余的部分下溢出了,你也能得到一个合理的值。
证明:
CTC 概率图前向概率:
代表了第t时刻第s个gt字符(经过补充空白符)的概率
在torch的ctc 前馈过程中,计算的log前向概率值的矩阵,(用以进行loss back),我们看到其核心:
- 每一个baych ,通过两层循环(T,S)动态规划计算前向概率Log值. 在计算的同时将同样需要计算的作为la1 ,la2,然后判断当前s的字符来决定第三个加项是否为
- 同时,因为转换到了对数域,也避免了数变小与溢出的问题,其项变成了
log_alpha_a[t][s] = std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + lamax + log_probs_a[t][current_target_prime];
源码分析:
使用了pytorch中ctcloss 的源码
//pytorch/aten/src/ATen/native/LossCTC.cpp
//获取填充blank后的target指定位置的值,用来判断是
static inline int64_t get_target_prime(target_t *target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK)
{
if (idx % 2 == 0)
{
return BLANK;
}
else
{
return target[offset + stride * (idx / 2)];
}
}
//ctc_loss_cpu_template部分核心代码
//前向概率[B,T,N]
Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1}, log_probs.options());
Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());
//[B,T,N]
auto lpp = log_probs.permute({1, 0, 2});
auto log_probs_a_global = lpp.accessor();
auto log_alpha_a_global = log_alpha.accessor();
auto targets_data = targets.data_ptr();
auto neg_log_likelihood_a = neg_log_likelihood.accessor();
// alpha calculation for the first row, the three equations for alpha_1 above eq (6)
// first the default
log_alpha.narrow(1, 0, 1).fill_(neginf);
at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
for (int64_t b = start; b < end; b++)
{
//每个batch
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
auto log_probs_a = log_probs_a_global[b];
auto log_alpha_a = log_alpha_a_global[b];
int64_t tg_batch_offset = tg_batch_offsets[b];
// the first two items of alpha_t above eq (6)
//初始化前向概率[t0][s0]
log_alpha_a[0][0] = log_probs_a[0][BLANK];
if (target_length > 0)
//[t0][s1]等于序列中第一个字符的概率
log_alpha_a[0][1] = log_probs_a[0][get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
// now the loop over the inputs
for (int64_t t = 1; t < input_length; t++)
{
for (int64_t s = 0; s < 2 * target_length + 1; s++)
{
//对于每一个s,计算其概率
//获取第s个字符是什么
auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
// this loop over s could be parallel/vectorized, too, but the required items are one index apart
// alternatively, one might consider moving s to the outer loop to cache current_target_prime more (but then it needs to be descending)
// for the cuda implementation, that gave a speed boost.
// This is eq (6) and (7), la1,2,3 are the three summands. We keep track of the maximum for the logsumexp calculation.
scalar_t la1 = log_alpha_a[t - 1][s];
scalar_t lamax = la1;
scalar_t la2, la3;
if (s > 0)
{
la2 = log_alpha_a[t - 1][s - 1];
if (la2 > lamax)
lamax = la2;
}
else
{
la2 = neginf;
}
if ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s - 2, BLANK) !=
current_target_prime))
{
//第s个字符不是空且不等于s-2(即不连续的时候),即动态转移方程的第二个式子
la3 = log_alpha_a[t - 1][s - 2];
if (la3 > lamax)
lamax = la3;
}
else
{
//s为空或者连续,第三项不用加
la3 = neginf;
}
//添加概率最大项按前一个[t-1][s]
if (lamax == neginf) // cannot do neginf-neginf
lamax = 0;
//计算此时的
// this is the assignment of eq (6)
log_alpha_a[t][s] = std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + lamax + log_probs_a[t][current_target_prime];
}
}
// the likelihood is the the sum of the last two alphas, eq (8), the loss is the negative log likelihood
if (target_length == 0)
{
// if the target is empty then there is no preceding BLANK state and hence there is no path to merge
neg_log_likelihood_a[b] = -log_alpha_a[input_length - 1][0];
}
else
{
scalar_t l1 = log_alpha_a[input_length - 1][target_length * 2];
scalar_t l2 = log_alpha_a[input_length - 1][target_length * 2 - 1];
//取两条路的概率之和
scalar_t m = std::max(l1, l2);
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
neg_log_likelihood_a[b] = -log_likelihood;
}
}
});
提供一个python 版本的numpy ctc的代码方便理解
import numpy as np
ninf = -np.float('inf')
def _logsumexp(a, b):
'''
np.log(np.exp(a) + np.exp(b))
'''
if a < b:
a, b = b, a
if b == ninf:
return a
else:
return a + np.log(1 + np.exp(b - a))
def logsumexp(*args):
'''
from scipy.special import logsumexp
logsumexp(args)
'''
res = args[0]
for e in args[1:]:
res = _logsumexp(res, e)
return res
class CTC:
def __init__(self):
pass
def forward(self):
pass
def alpha(self, log_y, labels):
##alpha 为前向概率
T, V = log_y.shape
L = len(labels)
log_alpha = np.ones([T, L]) * ninf
# init
## 初始化动态规划
log_alpha[0, 0] = log_y[0, labels[0]]
log_alpha[0, 1] = log_y[0, labels[1]]
##计算每一步,每个GT的前向概率
for t in range(1, T):
for i in range(L):
s = labels[i]
a = log_alpha[t - 1, i]
if i - 1 >= 0:
a = logsumexp(a, log_alpha[t - 1, i - 1])
##如果当前不是空白符,得加前两步的状态
if i - 2 >= 0 and s != 0 and s != labels[i - 2]:
a = logsumexp(a, log_alpha[t - 1, i - 2])
log_alpha[t, i] = a + log_y[t, s]
return log_alpha
def beta(self, log_y, labels):
##计算后向概率
T, V = log_y.shape
L = len(labels)
log_beta = np.ones([T, L]) * ninf
# init
log_beta[-1, -1] = log_y[-1, labels[-1]]
log_beta[-1, -2] = log_y[-1, labels[-2]]
for t in range(T - 2, -1, -1):
for i in range(L):
s = labels[i]
a = log_beta[t + 1, i]
if i + 1 < L:
a = logsumexp(a, log_beta[t + 1, i + 1])
if i + 2 < L and s != 0 and s != labels[i + 2]:
a = logsumexp(a, log_beta[t + 1, i + 2])
log_beta[t, i] = a + log_y[t, s]
return log_beta
def backward(selflog_y, labels):
T, V = log_y.shape
L = len(labels)
log_alpha = self.alpha(log_y, labels)
log_beta = self.beta(log_y, labels)
log_p = logsumexp(log_alpha[-1, -1], log_alpha[-1, -2])
##任意时刻的
log_grad = np.ones([T, V]) * ninf
for t in range(T):
for s in range(V):
lab = [i for i, c in enumerate(labels) if c == s]
for i in lab:
log_grad[t, s] = logsumexp(log_grad[t, s],
log_alpha[t, i] + log_beta[t, i])
log_grad[t, s] -= 2 * log_y[t, s]
log_grad -= log_p
return log_grad
def predict(self):
pass
def ctc_prefix(self):
pass
def ctc_beamsearch(self):
pass
def alpha_vanilla(self, y, labels):
T, V = y.shape # T,time step, V: probs
L = len(labels) # label length
alpha = np.zeros([T, L])
# init
alpha[0, 0] = y[0, labels[0]]
alpha[0, 1] = y[0, labels[1]]
for t in range(1, T):
for i in range(L):
s = labels[i]
a = alpha[t - 1, i]
if i - 1 >= 0:
a += alpha[t - 1, i - 1]
if i - 2 >= 0 and s != 0 and s != labels[i - 2]:
a += alpha[t - 1, i - 2]
alpha[t, i] = a * y[t, s]
return alpha
def beta_vanilla(self, y, labels):
##原始版计算前向概率,没在对数域中计算
T, V = y.shape
L = len(labels)
beta = np.zeros([T, L])
# init
beta[-1, -1] = y[-1, labels[-1]]
beta[-1, -2] = y[-1, labels[-2]]
for t in range(T - 2, -1, -1):
for i in range(L):
s = labels[i]
a = beta[t + 1, i]
if i + 1 < L:
a += beta[t + 1, i + 1]
if i + 2 < L and s != 0 and s != labels[i + 2]:
a += beta[t + 1, i + 2]
beta[t, i] = a * y[t, s]
return beta
def gradient(self, y, labels):
T, V = y.shape
L = len(labels)
alpha = self.alpha_vanilla(y, labels)
beta = self.beta(y, labels)
p = alpha[-1, -1] + alpha[-1, -2]
grad = np.zeros([T, V])
for t in range(T):
for s in range(V):
lab = [i for i, c in enumerate(labels) if c == s]
for i in lab:
grad[t, s] += alpha[t, i] * beta[t, i]
grad[t, s] /= y[t, s] ** 2
grad /= p
return grad
def check_grad(y, labels, w=-1, v=-1, toleration=1e-3):
grad_1 = gradient(y, labels)[w, v]
delta = 1e-10
original = y[w, v]
y[w, v] = original + delta
alpha = forward(y, labels)
log_p1 = np.log(alpha[-1, -1] + alpha[-1, -2])
y[w, v] = original - delta
alpha = forward(y, labels)
log_p2 = np.log(alpha[-1, -1] + alpha[-1, -2])
y[w, v] = original
grad_2 = (log_p1 - log_p2) / (2 * delta)
if np.abs(grad_1 - grad_2) > toleration:
print('[%d, %d]:%.2e' % (w, v, np.abs(grad_1 - grad_2)))
def remove_blank(labels, blank=0):
new_labels = []
# combine duplicate
previous = None
for l in labels:
if l != previous:
new_labels.append(l)
previous = l
# remove blank
new_labels = [l for l in new_labels if l != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def greedy_decode(y, blank=0):
raw_rs = np.argmax(y, axis=1)
rs = remove_blank(raw_rs, blank)
return raw_rs, rs
def beam_decode(y, beam_size=10):
T, V = y.shape
log_y = np.log(y)
beam = [([], 0)]
for t in range(T): # for every timestep
new_beam = []
for prefix, score in beam:
for i in range(V): # for every state
new_prefix = prefix + [i]
new_score = score + log_y[t, i]
new_beam.append((new_prefix, new_score))
# top beam_size
new_beam.sort(key=lambda x: x[1], reverse=True)
beam = new_beam[:beam_size]
return beam
def prefix_beam_decode(y, beam_size=10, blank=0):
T, V = y.shape
log_y = np.log(y)
beam = [(tuple(), (0, ninf))] # blank, non-blank
for t in range(T): # for every timestep
new_beam = defaultdict(lambda : (ninf, ninf))
for prefix, (p_b, p_nb) in beam:
for i in range(V): # for every state
p = log_y[t, i]
if i == blank: # propose a blank
new_p_b, new_p_nb = new_beam[prefix]
new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
new_beam[prefix] = (new_p_b, new_p_nb)
continue
else: # extend with non-blank
end_t = prefix[-1] if prefix else None
# exntend current prefix
new_prefix = prefix + (i,)
new_p_b, new_p_nb = new_beam[new_prefix]
if i != end_t:
new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
else:
new_p_nb = logsumexp(new_p_nb, p_b + p)
new_beam[new_prefix] = (new_p_b, new_p_nb)
# keep current prefix
if i == end_t:
new_p_b, new_p_nb = new_beam[prefix]
new_p_nb = logsumexp(new_p_nb, p_nb + p)
new_beam[prefix] = (new_p_b, new_p_nb)
# top beam_size
beam = sorted(new_beam.items(), key=lambda x : logsumexp(*x[1]), reverse=True)
beam = beam[:beam_size]
return beam