ctcloss理解及ctcloss使用报错总结

ctcloss理解及ctcloss使用报错总结

ctcloss函数主要用在没有事先对齐的序列化数据训练上,比如语音识别,ocr识别等,主要的优点是可以对没有对齐的数据进行自动对齐。

  1. L = a , o , e , i , u , b , p , m , f , ⋯ L={a,o,e,i,u,b,p,m,f,\cdots} L=a,o,e,i,u,b,p,m,f, 表示所有字符的集合。

  2. π = ( π 1 , π 2 , ⋯   , π T ) , π i ε L π=(π_1,π_2,\cdots,π_T),π_i\varepsilon L π=(π1,π2,,πT),πiεL 表示一条由L中元素组成的长度为T的路径,表示模型的输出序列。

  3. l = ( l 1 , l 2 , ⋯   , l m ) , l i ε L l = (l_1,l_2,\cdots,l_m),l_i\varepsilon L l=(l1,l2,,lm)liεL, 表示一条由L中元素组成的长度为m的路径,即为ground truth标签。

  4. y k t ( k = 1 , 2 ⋯   , n , t = 1 , 2 ⋯   , T ) y^t_k(k=1,2\cdots,n,t=1,2\cdots,T) ykt(k=1,2,n,t=1,2,T):表示在t时刻,输出为k的概率,举个简单的例子:当输出的序列为(a-ab-)时, y a 3 y^3_a ya3代表了在第3步输出的字母为a的概率;

  5. p ( π ∣ x ) p(π|x) p(πx):代表了给定输入x下,输出路径为 π π π的概率。

    由于假设在每一个时间步输出的label的概率都是相互独立的,那么 p ( π ∣ x ) p(π|x) p(πx)用公式来表示为 p ( π ∣ x ) = ∏ t = 1 T y k t p(π|x)=\prod_{t=1}^{T}{y^t_k} p(πx)=t=1Tykt,可以理解为每一个时间步输出路径 $ π$ 的相应label的概率的乘积。

  6. p ( π ∣ x ) = p ( π ∣ y = N w ( x ) ) = p ( π ∣ y ) = ∏ t = 1 T y k t p(π|x)=p(π|y=N_w(x))=p(π|y)=\prod_{t=1}^{T}{y^t_k} p(πx)=p(πy=Nw(x))=p(πy)=t=1Tykt ,y是经过softmax层得到的后验概率, N w ( x ) N_w(x) Nw(x)可理解为DNN 模型对输入提取feature后输出概率的一种变换。

  7. p ( l ∣ x ) p(l|x) p(lx):代表给定输入x,输出为序列$l $的概率

    因此输出的序列为 $l $ 的概率可以表示为所有输出的路径 π 映射后的序列为 l 的概率之和,用公式表示为

    p ( l ∣ y = N w ( x ) ) = p ( l ∣ x ) = ∑ p ( π ∣ x ) p(l|y=N_w(x))=p(l|x)=\sum{p(π|x)} p(ly=Nw(x))=p(lx)=p(πx) subject to. { π ∣ B ( π ) = l {π|B(π)=l} πB(π)=l} ,其中{ π ∣ B ( π ) = l {π|B(π)=l} πB(π)=l}表示所有输出的路径π做压缩映射变换后等于$l $的路径集合。

总结:给定样本x后输出正确label 为$l $的概率的计算公式如下:

p ( l ∣ x ) = p ( l ∣ y = N w ( x ) ) = ∑ p ( π ∣ x ) = ∑ ∏ t = 1 T y k t p(l|x)=p(l|y=N_w(x))=\sum{p(π|x)}=\sum\prod_{t=1}^{T}{y^t_k} p(lx)=p(ly=Nw(x))=p(πx)=t=1Tykt

ctcloss表达式为

ctcloss = − l n p ( l ∣ x ) -lnp(l|x) lnp(lx)

ctcloss就是计算一个输入序列x,模型预测序列中所有能映射到标签label的输出序列概率总和的负ln

而连乘再求和的计算可以采用 前置项* y k t y^t_k ykt*后置项 计算。

举例如下
ctcloss理解及ctcloss使用报错总结_第1张图片

然后前置项,后置项通过动态规划即可求解。

推荐参考文章

https://blog.csdn.net/huangyiping12345/article/details/102668605

https://blog.csdn.net/luodongri/article/details/77005948

https://www.jianshu.com/p/0cca89f64987

pytorch内置nn.CTCLoss方法详解

https://zhuanlan.zhihu.com/p/67415439

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

我们在crnn+ctc训练文字识别项目时,

log_probs:模型输出张量shape为(T, B, C) ,其中T是模型输出时图像的宽度,一般称为input_length也即输出序列长度,此值是受模型输入时图像的width大小所影响,B是batch_size大小,C是包括空白标签的字符集的长度。

input_lengths:张量shape为(B, ) 常用preds_size = torch.IntTensor([preds.size(0)] * batch_size)得到此张量,preds.size(0)就是输入序列长度。

targets:张量shape(sum(target_lengths), ),每个元素是batch_size个text标签中的字符依次对应的索引,如前10个元素是第一个text标签的字符的索引,依次6个元素是第二个标签的字符索引。这里的标签不包括空白标签。张量shape也可以为(B, S),S则为标签长度,意味着需要将batch个label的one-hot标量组装成相同长度。

**targets_lengths:张量shape为(B,),每个元素是每个样本的text标签长度,标签长度可能各不相同。**每个元素分别指定了按序取targets多少个元素来表示一个text标签:

def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """

        length = []
        result = []
        decode_flag = True if type(text[0])==bytes else False

        for item in text:

            if decode_flag:
                item = item.decode('utf-8','strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))
    
    
 # ctc使用时常常看到这样的用法:
 preds = crnn(image)  # preds shape为(T,  B,  C) 
 # Tensor for argument #2 'targets' is on CPU, but expected it to be on GPU (while        #checking arguments for ctc_loss_gpu)
 # https://github.com/Sierkinhane/crnn_chinese_characters_rec/issues/124#issuecomment-517516922
 # https://github.com/ypwhs/captcha_break/issues/29
 # https://github.com/pytorch/pytorch/issues/14401#issuecomment-447710313
 # https://github.com/Sierkinhane/crnn_chinese_characters_rec/issues/124#issuecomment-524738484
 preds = preds.to(torch.float64)  # 此行是解决pytorch ctcloss的内部bug
 preds = preds.to(device)
 batch_size = image.size(0)
 text, length = converter.encode(label)  # encode就是上面这个函数
 text = text.to(device)
 preds_size = torch.IntTensor([preds.size(0)] * batch_size)
 cost = criterion(preds, text, preds_size, length) / batch_size #  criterion就是ctc_loss

注意:CTCLoss 要求 input_length >= 2 * target_length + 1, input_length是模型输出时的序列长度,target_length是label的字符长度

pytorch ctcloss 使用时常常出现的错误

1, Tensor for argument #2 ‘targets’ is on CPU, but expected it to be on GPU (while checking arguments for ctc_loss_gpu)

解决方法,https://github.com/Sierkinhane/crnn_chinese_characters_rec/issues/124#issuecomment-524738484, https://github.com/pytorch/pytorch/issues/14401#issuecomment-447710313

2, loss为NAN

解决方法,https://github.com/ypwhs/captcha_break/issues/29,

https://github.com/Sierkinhane/crnn_chinese_characters_rec/issues/124#issuecomment-517516922

3,在使用torchsummary打印模型各层形状不错**‘tuple’ object has no attribute ‘size’**

这是pytorch summary的一个bug,因为pytorch在写LSTM的时候的output是一个tuple,除了output之外还有所谓的hn,cn然而summary无法识别tuple,所以才会报这个错误。

你可能感兴趣的:(计算机视觉)