ctcloss函数主要用在没有事先对齐的序列化数据训练上,比如语音识别,ocr识别等,主要的优点是可以对没有对齐的数据进行自动对齐。
L = a , o , e , i , u , b , p , m , f , ⋯ L={a,o,e,i,u,b,p,m,f,\cdots} L=a,o,e,i,u,b,p,m,f,⋯ 表示所有字符的集合。
π = ( π 1 , π 2 , ⋯ , π T ) , π i ε L π=(π_1,π_2,\cdots,π_T),π_i\varepsilon L π=(π1,π2,⋯,πT),πiεL 表示一条由L中元素组成的长度为T的路径,表示模型的输出序列。
l = ( l 1 , l 2 , ⋯ , l m ) , l i ε L l = (l_1,l_2,\cdots,l_m),l_i\varepsilon L l=(l1,l2,⋯,lm),liεL, 表示一条由L中元素组成的长度为m的路径,即为ground truth标签。
y k t ( k = 1 , 2 ⋯ , n , t = 1 , 2 ⋯ , T ) y^t_k(k=1,2\cdots,n,t=1,2\cdots,T) ykt(k=1,2⋯,n,t=1,2⋯,T):表示在t时刻,输出为k的概率,举个简单的例子:当输出的序列为(a-ab-)时, y a 3 y^3_a ya3代表了在第3步输出的字母为a的概率;
p ( π ∣ x ) p(π|x) p(π∣x):代表了给定输入x下,输出路径为 π π π的概率。
由于假设在每一个时间步输出的label的概率都是相互独立的,那么 p ( π ∣ x ) p(π|x) p(π∣x)用公式来表示为 p ( π ∣ x ) = ∏ t = 1 T y k t p(π|x)=\prod_{t=1}^{T}{y^t_k} p(π∣x)=∏t=1Tykt,可以理解为每一个时间步输出路径 $ π$ 的相应label的概率的乘积。
p ( π ∣ x ) = p ( π ∣ y = N w ( x ) ) = p ( π ∣ y ) = ∏ t = 1 T y k t p(π|x)=p(π|y=N_w(x))=p(π|y)=\prod_{t=1}^{T}{y^t_k} p(π∣x)=p(π∣y=Nw(x))=p(π∣y)=∏t=1Tykt ,y是经过softmax层得到的后验概率, N w ( x ) N_w(x) Nw(x)可理解为DNN 模型对输入提取feature后输出概率的一种变换。
p ( l ∣ x ) p(l|x) p(l∣x):代表给定输入x,输出为序列$l $的概率
因此输出的序列为 $l $ 的概率可以表示为所有输出的路径 π 映射后的序列为 l 的概率之和,用公式表示为
p ( l ∣ y = N w ( x ) ) = p ( l ∣ x ) = ∑ p ( π ∣ x ) p(l|y=N_w(x))=p(l|x)=\sum{p(π|x)} p(l∣y=Nw(x))=p(l∣x)=∑p(π∣x) subject to. { π ∣ B ( π ) = l {π|B(π)=l} π∣B(π)=l} ,其中{ π ∣ B ( π ) = l {π|B(π)=l} π∣B(π)=l}表示所有输出的路径π做压缩映射变换后等于$l $的路径集合。
总结:给定样本x后输出正确label 为$l $的概率的计算公式如下:
p ( l ∣ x ) = p ( l ∣ y = N w ( x ) ) = ∑ p ( π ∣ x ) = ∑ ∏ t = 1 T y k t p(l|x)=p(l|y=N_w(x))=\sum{p(π|x)}=\sum\prod_{t=1}^{T}{y^t_k} p(l∣x)=p(l∣y=Nw(x))=∑p(π∣x)=∑∏t=1Tykt
ctcloss表达式为
ctcloss = − l n p ( l ∣ x ) -lnp(l|x) −lnp(l∣x)
ctcloss就是计算一个输入序列x,模型预测序列中所有能映射到标签label的输出序列概率总和的负ln
而连乘再求和的计算可以采用 前置项* y k t y^t_k ykt*后置项 计算。
然后前置项,后置项通过动态规划即可求解。
推荐参考文章
https://blog.csdn.net/huangyiping12345/article/details/102668605
https://blog.csdn.net/luodongri/article/details/77005948
https://www.jianshu.com/p/0cca89f64987
pytorch内置nn.CTCLoss方法详解
https://zhuanlan.zhihu.com/p/67415439
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
我们在crnn+ctc训练文字识别项目时,
log_probs:模型输出张量shape为(T, B, C) ,其中T是模型输出时图像的宽度,一般称为input_length也即输出序列长度,此值是受模型输入时图像的width大小所影响,B是batch_size大小,C是包括空白标签的字符集的长度。
input_lengths:张量shape为(B, ) 常用preds_size = torch.IntTensor([preds.size(0)] * batch_size)得到此张量,preds.size(0)就是输入序列长度。
targets:张量shape(sum(target_lengths), ),每个元素是batch_size个text标签中的字符依次对应的索引,如前10个元素是第一个text标签的字符的索引,依次6个元素是第二个标签的字符索引。这里的标签不包括空白标签。张量shape也可以为(B, S),S则为标签长度,意味着需要将batch个label的one-hot标量组装成相同长度。
**targets_lengths:张量shape为(B,),每个元素是每个样本的text标签长度,标签长度可能各不相同。**每个元素分别指定了按序取targets多少个元素来表示一个text标签:
def encode(self, text):
"""Support batch or single str.
Args:
text (str or list of str): texts to convert.
Returns:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
"""
length = []
result = []
decode_flag = True if type(text[0])==bytes else False
for item in text:
if decode_flag:
item = item.decode('utf-8','strict')
length.append(len(item))
for char in item:
index = self.dict[char]
result.append(index)
text = result
return (torch.IntTensor(text), torch.IntTensor(length))
# ctc使用时常常看到这样的用法:
preds = crnn(image) # preds shape为(T, B, C)
# Tensor for argument #2 'targets' is on CPU, but expected it to be on GPU (while #checking arguments for ctc_loss_gpu)
# https://github.com/Sierkinhane/crnn_chinese_characters_rec/issues/124#issuecomment-517516922
# https://github.com/ypwhs/captcha_break/issues/29
# https://github.com/pytorch/pytorch/issues/14401#issuecomment-447710313
# https://github.com/Sierkinhane/crnn_chinese_characters_rec/issues/124#issuecomment-524738484
preds = preds.to(torch.float64) # 此行是解决pytorch ctcloss的内部bug
preds = preds.to(device)
batch_size = image.size(0)
text, length = converter.encode(label) # encode就是上面这个函数
text = text.to(device)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
cost = criterion(preds, text, preds_size, length) / batch_size # criterion就是ctc_loss
注意:CTCLoss 要求 input_length >= 2 * target_length + 1, input_length是模型输出时的序列长度,target_length是label的字符长度
pytorch ctcloss 使用时常常出现的错误
1, Tensor for argument #2 ‘targets’ is on CPU, but expected it to be on GPU (while checking arguments for ctc_loss_gpu)
解决方法,https://github.com/Sierkinhane/crnn_chinese_characters_rec/issues/124#issuecomment-524738484, https://github.com/pytorch/pytorch/issues/14401#issuecomment-447710313
2, loss为NAN
解决方法,https://github.com/ypwhs/captcha_break/issues/29,
https://github.com/Sierkinhane/crnn_chinese_characters_rec/issues/124#issuecomment-517516922
3,在使用torchsummary打印模型各层形状不错**‘tuple’ object has no attribute ‘size’**
这是pytorch summary的一个bug,因为pytorch在写LSTM的时候的output是一个tuple,除了output之外还有所谓的hn,cn然而summary无法识别tuple,所以才会报这个错误。