Error:w tensorflow/core/util/ctc/ctc_loss_calculator.cc:144] No valid path found.

Error:w tensorflow/core/util/ctc/ctc_loss_calculator.cc:144] No valid path found.

文本识别:DenseNet + CTC(如果不想修改输入序列长度的情况下)

1.在使用自己的数据训练这个模型的过程中对于这个error查了很多人的回答,但是还是解决不了,大部分是说CTC在计算loss时要求输入序列长度不小于标签长度,也就是标签长度要小于输入序列长度。比如输入序列长度为25,那么标签长度不得大于25,还有的说不得大于25-1等等。我都设置了,均不成效!!!

2.另一种说法是再计算CTCloss的时候添加一个参数:preprocess_collapse_repeated = True

loss = tf.nn.ctc_loss(labels=targets, inputs=logits, sequence_length=seq_len)

----->

loss = tf.nn.ctc_loss(labels=targets, inputs=logits, sequence_length=seq_len,preprocess_collapse_repeated = True)

告诉大家这个参数的含义是啥呢,就是在预处理的时候将所有重复的字符都合并成一个标签。什么意思呢,最终的到的模型在预测字符2333的时候,返回的是23,这是很坑的一点,所以除非是bu想有重复的字符出现否则不要尝试!!!

Error:w tensorflow/core/util/ctc/ctc_loss_calculator.cc:144] No valid path found._第1张图片

3.最后,我记不清在哪里看到的回答, CTC在计算loss时要求不光标签长度要小于输入序列长度,而且标签长度加上重复字符的个数总和要小于输入序列长度。就是需要加上e的数量。比如字符23333,那么总长度是5+4=9。所以我把数据集中的标签都过滤了一遍:

def data_filter(train_path,valid_path,length):
    train_data_list = []
    train_list=os.listdir(train_path)
    for train_ in train_list:
        if train_.endswith(".txt"):
            pass
        else:
            with open(train_path+train_+'/train_word.txt', 'r')as f:
                line = f.readline()
                while line:
                    if ',' in line:
                        one_data = line.split(',')
                        label=one_data[1].split(' ')
                    else:
                        one_data = line.split(' ')
                        label = one_data[1:]
                    label =  [ int(x) - 1 for x in label ]
                    runs = [len(list(g)) for _, g in groupby(label)]
                    repeats = sum(u for u in runs if u > 1)
                    if  len(label)+repeats 1)
                    if  len(label)+repeats
length=100
train_data_list,valid_data_list=data_filter('','',length/4)

最终过滤完数据之后再来训练,error不做出现!! 

如果有什么错误请指正哈,感谢-_-

 

你可能感兴趣的:(Code-error)