CNN+GRU+CTC实现不定长字符串识别(二)

对识别结果进行投票

    • 介绍
    • 投票思路
    • 源码
    • 结果

介绍

一拳难敌四手,对于模型来说也是这样,单个模型的准确率终究还是不如多个模型综合起来准确率高,这里我简单的训练了四个模型,找一找他们之间的关系。
我用的训练集依然是tinymind的人民币编码识别,我训练了三种模型,ResNet * 1,DenseNet * 1,Xception * 2,因为时间和硬件问题也没有训练很多,正确率分别如下:
ResNet:
在这里插入图片描述
DenseNet:
在这里插入图片描述
Xception:
在这里插入图片描述
这个平台还是很好用的,不但有准确率,还能看损失

投票思路

CNN+GRU+CTC实现不定长字符串识别(二)_第1张图片

源码

'''
筛选
如果最后一位是Z,说明识别时错位,直接删掉
前三位可能是数字或字母,后七位只能是数字
出现错误的label将会直接删除,如果全部出错,label改为1111111111,需手动操作
投票
少数服从多数
'''

import os
vote_dir = './vote_dir/'
result_dir = './result/'
str1 = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
str1_list = list(str1)

if os.path.exists(vote_dir) == False:
    os.makedirs(vote_dir)
if os.path.exists(result_dir) == False:
        os.makedirs(result_dir)

dir_list = os.listdir(vote_dir)
label_num = 9988 #标签数量,不算第一行

label_list = [['0' for i in range(len(dir_list))]for i in range(label_num)]
print(label_list)

def check(label):
    for i in range(len(label)):
        lab = label[i]
        lab = lab[14:]
        lab = list(lab)
        print('筛选{}'.format(lab))
        for n in range(len(lab)):
            l_index = str1_list.index(lab[n])
            #print('{}的index{}'.format(lab[i], l_index))
            if n > 2 and l_index > 9:
                if len(label) == 1:
                    return [label[0][:14]+'1111111111'],True
                label.remove(label[i])
                return label,False
    return label,True

for i in range(len(dir_list)):
    with open(vote_dir + dir_list[i],'r') as file:
        tlist = file.readlines()
        for n in range(len(tlist)):
            if n == 0:
                continue
            text = tlist[n][:-1]
            label_list[n-1][i] = text
            #print('{},{},{}'.format(n,i,text))

with open(result_dir+'result.csv','w') as result:
    result.writelines('name,label'+'\n')
    for label in label_list:
        print(label)
        #筛选
        while True:
            label,re = check(label)
            if re == True:
                break
        #投票
        one_label = []
        for i in range(10):
            temp = []
            for n in range(len(label)):
                lab = label[n][14:]
                lab = list(lab)
                lab = lab[i]
                temp.append(lab)
            lab = max(set(temp), key=temp.count)
            one_label.append(lab)
        one_label = ''.join(one_label)
        label = label[0][:14] + one_label
        result.writelines(label+'\n')
        print(label)

结果

在这里插入图片描述
结果就是相比较最高的98.5,正确率提高了0.4个百分点,当然也可能是我写的投票脚本不够好,但是也可以看出,正确率的高低最主要还是要看单个模型的正确率,投票进行融合只能说锦上添花,并不能起决定性作用,同时我还考虑了是不是因为我那个DenseNet的模型训练的不够好,所以拉低了平均水平,我把DenseNet的结果删了之后,用剩下的3个结果又进行了一次融合,正确率并没有提高
在这里插入图片描述
反而降低了0.05,所以说如果模型训练的不是特别差的话还是越多越好。

你可能感兴趣的:(python,计算机视觉,结果投票,序列识别)