车牌识别文字识别crnn_plate_recognition训练以及代码解析

车牌识别文字识别训练全过程解析 目前代码解读还不算完善 后续会补充

车牌识别github链接

车牌识别文字识别github链接

车牌检测end2end实现过程

训练方式按照github上介绍就行

在解释前定义几个方便理解

plate_chr="#京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航0123456789ABCDEFGHJKLMNPQRSTUVWXYZ危险品"  
plate_name="京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航0123456789ABCDEFGHJKLMNPQRSTUVWXYZ危险品"  
plateDict2={'京':0, '京':1, '沪':2, ......, '-': 77}  
plateDict={'#':0, '京':1, '沪':2 ......}  
pchar表示车牌单个字符 如'京'  
pstr表示每一个车牌字符串 如'云A008BC'  
p_number表示车牌字符对应的数字, 即plateDict中的0

解析数据集打上标签,生成train.txt和val.txt的程序

  • 生成程序
    python plateLabel.py --image_path your/train/img/path/ --label_file datasets/train.txt
    python plateLabel.py --image_path your/val/img/path/ --label_file datasets/val.txt
    
  • plateLabel.py解析
    import os
    import argparse
    from alphabets import plate_chr  # 导入车牌可能出现的所有字符
    # 遍历rootfile文件下所有图片 
    def allFileList(rootfile,allFile):
        folder =os.listdir(rootfile)
        for temp in folder:
            fileName = os.path.join(rootfile,temp)
            if os.path.isfile(fileName):
                allFile.append(fileName)
            else:
                allFileList(fileName,allFile)
    # 判断车牌名是不是在palteStr中  当车排名不在plateStr中的 return False
    def is_str_right(plate_name):
        for str_ in plate_name:
            if str_ not in palteStr:
                return False
        return True
    
    if __name__=="__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument('--image_path', type=str, default="datasets/val", help='source') 
        parser.add_argument('--label_file', type=str, default='datasets/val.txt', help='model.pt path(s)')  
        
        opt = parser.parse_args()
        rootPath = opt.image_path
        labelFile = opt.label_file
        # palteStr=r"#京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民深危险品0123456789ABCDEFGHJKLMNPQRSTUVWXYZ"
        # palteStr=r"#京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航深0123456789ABCDEFGHJKLMNPQRSTUVWXYZ"
        palteStr=plate_chr
        print(len(palteStr))
    
        # 生成一个字典plateDict
        plateDict ={}
        for i in range(len(list(palteStr))):
            plateDict[palteStr[i]]=i
        fp = open(labelFile,"w",encoding="utf-8")
        file =[]
        allFileList(rootPath,file)  # 遍历rootPath下所有图片  保存在file中
        picNum = 0
    
        # 遍历每一张图片
        for jpgFile in file:
            print(jpgFile)
            jpgName = os.path.basename(jpgFile)  # 获得图片名称  如: 云A008BC_0.jpg
            name =jpgName.split("_")[0]  # 获得车牌文字pstr  如: 云A008BC
            if " " in name:
                continue
            labelStr=" "
            if not is_str_right(name):  # 如果车牌文字pstr存在不在plateDict中的字符pchar 则直接continue
                continue
            strList = list(name)  # 将车牌文字转化为列表 如: ['云','A','0','0','8','B','C']
            for  i in range(len(strList)):
                labelStr+=str(plateDict[strList[i]])+" "  # 将车牌文字转化为对应的数字p_number 如: "25 52 42 42 50 53 54"
            # while i<7:
            #     labelStr+=str(0)+" "
            #     i+=1
            picNum+=1
            # print(jpgFile+labelStr)
            fp.write(jpgFile+labelStr+"\n")  # 将图片路径和对应的标签写入labelFile中  如 datasets/val\云A008BC_0.jpg 25 52 42 42 50 53 54 
        fp.close()
    

代码解析 按照train.py的代码一步一步解析, 只阐述重点的地方

训练代码解析
  • 加载config: config = parse_arg()

    def parse_arg():
        parser = argparse.ArgumentParser(description="train crnn")
        
        parser.add_argument('--cfg', help='experiment configuration filename', default='./lib/config/360CC_config.yaml', type=str)  # 配置文件
        parser.add_argument('--img_h', type=int, default=48, help='height')  # 模型input的h
        parser.add_argument('--img_w',type=int,default=168,help='width')     # 模型input的w
        args = parser.parse_args()
       
        with open(args.cfg, 'r') as f:
            # config = yaml.load(f, Loader=yaml.FullLoader)
            config = yaml.load(f, Loader=yaml.FullLoader)
            config = edict(config)  # 将config转化为edict形式的  即从config['DATASET']['ALPHABETS']变成config.DATASET.ALPHABETS']
    
        config.DATASET.ALPHABETS = plateName  # 字符集plate_name 比plate_chr少了一个blank字符"#"
        config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)  # 字符集plate_name长度 77
        config.HEIGHT=args.img_h  # 输入图片的h
        config.WIDTH = args.img_w  # 输入图片的w
        return config
    
  • 所有保存文件的输出路径: output_dict = utils.create_log_folder(config, phase=‘train’)

    def create_log_folder(cfg, phase='train'):
        root_output_dir = Path(cfg.OUTPUT_DIR)  
        # set up logger
        if not root_output_dir.exists():  
            print('=> creating {}'.format(root_output_dir))
            root_output_dir.mkdir()
    
        dataset = cfg.DATASET.DATASET  #数据集名称  '360CC'
        model = cfg.MODEL.NAME  #模型名称  'crnn'
    
        time_str = time.strftime('%Y-%m-%d-%H-%M')  #时间  2023-12-14-16-09
        checkpoints_output_dir = root_output_dir / dataset / model / time_str / 'checkpoints'  #输出文件路径
    
        print('=> creating {}'.format(checkpoints_output_dir))
        checkpoints_output_dir.mkdir(parents=True, exist_ok=True)
    
        tensorboard_log_dir = root_output_dir / dataset / model / time_str / 'log'  #tensotborad日志路径
        print('=> creating {}'.format(tensorboard_log_dir))
        tensorboard_log_dir.mkdir(parents=True, exist_ok=True)
    
    
        return {'chs_dir': str(checkpoints_output_dir), 'tb_dir': str(tensorboard_log_dir)}
    
  • 数据集的读取

    class _360CC(data.Dataset):
        def __init__(self, config, input_w=168,input_h=48,is_train=True):
    
            self.root = config.DATASET.ROOT
            self.is_train = is_train
            self.inp_h = config.MODEL.IMAGE_SIZE.H
            self.inp_w = config.MODEL.IMAGE_SIZE.W
            self.input_w = input_w  # 输入图片的宽
            self.input_h= input_h   # 输入图片的高
            self.dataset_name = config.DATASET.DATASET
    
            self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
            self.std = np.array(config.DATASET.STD, dtype=np.float32)
    
            char_file = config.DATASET.CHAR_FILE
            # with open(char_file, 'rb') as file:
            #     char_dict = {num: char.strip().decode('gbk', 'ignore') for num, char in enumerate(file.readlines())}
            # with open(char_file, 'r',encoding='utf-8') as file:
            #     char_dict = {num: char.strip()  for num, char in enumerate(file.readlines())}
                # I resaved char_std_5990.txt in utf-8 format, so no need decode gbk
                # char_dict = {num: char.strip() for num, char in enumerate(file.readlines())}
            char_dict = {num:char.strip() for num,char in enumerate(plate_chr)}
            char_dict[0]="blank"  #训练的字符字典plateDict中第一个代表的是空白, 这个跟CTCLoss有关, 可以看一看CTCLoss后就可以理解了
            txt_file = config.DATASET.JSON_FILE['train'] if is_train else config.DATASET.JSON_FILE['val']
    
            # convert name:indices to name:string
            self.labels = []
            with open(txt_file, 'r', encoding='utf-8') as file:
                contents = file.readlines()
                for c in contents:
                    c=c.strip(" \n")
                    imgname = c.split(' ')[0]
                    indices = c.split(' ')[1:]
                    string = ''.join([char_dict[int(idx)] for idx in indices])
                    self.labels.append({imgname: string})
    
            print("load {} images!".format(self.__len__()))
    
        def __len__(self):
            return len(self.labels)
    
        def __getitem__(self, idx):
    
            img_name = list(self.labels[idx].keys())[0]
            # img = cv2.imread(os.path.join(self.root, img_name))
            img = cv_imread(os.path.join(self.root, img_name))
            if img.shape[-1]==4:
                img=cv2.cvtColor(img,cv2.COLOR_BGRA2BGR)
            # img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    
            img_h, img_w ,_= img.shape
    
            # img = cv2.resize(img, (0,0), fx=self.inp_w / img_w, fy=self.inp_h / img_h, interpolation=cv2.INTER_CUBIC)
            img = cv2.resize(img, (self.input_w,self.input_h))
            # img = np.reshape(img, (48, 168, 3))
            # img = np.reshape(img, (self.inp_h, self.inp_w, 1))
    
            img = img.astype(np.float32)
            img = (img/255. - self.mean) / self.std
            img = img.transpose([2, 0, 1])  #[h, w, c] -> [c, h, w] 这里没有brg -> rgb 在end2end预测的时候也没有 所以是可以的
    
            return img, idx
    
  • 模型的训练过程中也有一些注意的事项 ———— lib/core/function.py (def train)

    # 注意这里的idx还是索引, 可以从上面的数据集读取上看到return idx
    # labels: ['苏A8C4A8', '川AE00K0', '冀EL2392', '鲁ATN619', '川A5E1Z9', '闽FQZ790', '辽MD7792'......] len=256
    labels = utils.get_batch_label(dataset, idx)
    
    # inference
    # preds: torch.Size([21, 256, 78]) 
    # 21: 车牌预测的字符个数的最大上限 也就是一张车牌最多预测21个字符pchar
    # 256:图片的batchsize 
    # 78:车牌字符集:77 + 1 (77是车牌字符串plate_name的长度 1是blank, 相当于#)
    # 这个空白#的存在其实和CTCLoss有关 这里就不过多介绍了
    preds = model(inp).cpu()
    

    计算损失这里可能难以理解

    # compute loss
    
    # batchsize: 256
    batch_size = inp.size(0)
    
    # text:  tensor([11, 52, 50,  ..., 47, 42, 47], dtype=torch.int32)  shape: torch.Size([1798])
    # length:  tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, ......]  shape:torch.Size([256])
    # 从上面的输出结果可以看出text中的是将labels中所有的车牌字符串pstr都拼接在一起, 其中的值代表的是每一个plate_chr{#京沪......}对应的下标
    # length中的值则可以很轻松的看出是每一个车牌pstr的长度
    text, length = converter.encode(labels)                    # length = 一个batch中的总字符长度, text = 一个batch中的字符所对应的下标
    
    # preds_size:  tensor([21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, ......]  shape:torch.Size([256])
    preds_size = torch.IntTensor([preds.size(0)] * batch_size) # timestep * batchsize
    
    # torch官网上的CTCLoss的使用的参数要求 可以直接看官网 官网更详细
    # preds:  (T, N, C)  T=input length  N=batch size  C=number of classes(including blank)
    # text: (N, S) or (sum(target_lengths)) sum就是将所有的字符串pstr拼接在一起并转化为对应的plateDict下标  其中0是blank
    # preds_size: (N, )
    # length: (N, )
    loss = criterion(preds, text, preds_size, length)
    

    CTCLoss的解析推荐看这两篇文章
    知乎
    博客园

  • text, length = converter.encode(labels)中的converter问题: converter = utils.strLabelConverter(config.DATASET.ALPHABETS)

    # encode的作用就是利用{'#': 0, '京': 1, ......}中的一一对应关系 将车牌名字转化为对应的数字
    def encode(self, text):
        """Support batch or single str.
    
        Args:
            text (str or list of str): texts to convert.
    
        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
    
        length = []
        result = []
        decode_flag = True if type(text[0])==bytes else False
    
        for item in text:
    
            if decode_flag:
                item = item.decode('utf-8','strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))
    

训练过程中的验证代码解析

  • 模型验证过程中也会有一些注意的地方 ———— lib\core\function.py (def validate)

    # preds: shape:(21, 128, 78)
    # 在max之后 _: 是最大值 shape(21, 128)  preds: 是最大值的索引 shape(21, 128)
    # 这个部分主要是选出车牌字符串集plateDict 78中最大概率的那个作为该位置的输出
    _, preds = preds.max(2)
    
    # preds: torch.Size([2688])
    # 先转化为[128, 21], 主要是为了decode中每一个相邻的21个位置都是同一个车牌上的预测结果
    preds = preds.transpose(1, 0).contiguous().view(-1)
    
    # preds: tensor([30,  0,  0,  ..., 43, 43,  0])  torch.Size([2688]) 
    # preds_size: tensor([21, 21, 21, ......])  shape: torch.Size([128])  
    sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
    
  • decode代码解析 lib\utils\utils.py (def encode())
    这里的converter.decode是跟CTCLoss进行配合的, 需要好好理解一下
    这个部分的作用首先是将预测输出的torch.size([128, 21])中的所有21个对应的最大概率的索引(下标为0-77)转化为对应的plateDict2中的字符, 这里为什么不是{‘-’: 0, …}即blank作为0, 这个原因可以见decode代码解析
    然后将转化好的21个字符中去掉重复的字符, 直接举个例子吧, -代表是blank, 以这个间隔, 删掉重复的字母
    左边是转化好的21个字符, 右边是得到的字符

    ---E--D--3-S-11-22-- => 苏ED3S12             , gt: 苏ED3S12
    赣--EE--K--1--3-22-0-- => 赣EK1320             , gt: 赣EK1320
    渝-AA---66-P--8--6-11- => 渝A6P861             , gt: 渝A6P861
    新--LL--66-0-11-77-99- => 新L60179             , gt: 新L60179
    鄂---N--66-MM-Y-11-22- => 鄂N6MY12             , gt: 鄂N6MY12
    豫---B--D--8--2-11-11- => 豫BD8211             , gt: 渝BD8211
    苏--E---1--1-LL-33-LL- => 苏E11L3L             , gt: 苏E11L3L
    赣---C--55-55-T--6-22- => 赣C55T62             , gt: 赣C55T62
    川-AA---88-8--Y--3-55- => 川A88Y35             , gt: 川A88Y35
    例如 川-AA---88 => 就会删掉一个A ,删掉所有的-, 删掉一个8, 得到川A8
    
    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.
    
        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
    
        Raises:
            AssertionError: when the texts and its length does not match.
    
        Returns:
            text (str or list of str): texts to convert.
        """
        # 这个就是一个[21]的车牌字符的转化: 川-AA---88-8--Y--3-55- => 川A88Y35
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    # t[i] != 0 代表的是不为'-', not(i>0 and t[i-1]==t[i])表示的是当是第一个字符或者前后两个字符是相同的时候为True
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])  #将对应的索引转化为车牌字符
                return ''.join(char_list)
        # 这个是将[128, 21]分开为128个21, 即每一个车牌单独的送入上面的if length.numel() == 1:中
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
    
  • 验证集指标的出现 ———— lib\core\function.py (def validate)

    for pred, target in zip(sim_preds, labels):
        sum+=1
        if pred == target:
            n_correct += 1
    
    accuracy = n_correct / sum
    

    很明显, 这个指标是完全预测准确的车牌/总的预测的车牌

最终训练结果展示

(lvxiaoleother) C:\Users\HUST\Desktop\crnn_plate_recognition> c: && cd c:\Users\HUST\Desktop\crnn_plate_recognition && cmd /C "C:\Users\HUST\anaconda3\envs\lvxiaoleother\python.exe c:\Users\HUST\.vscode\extensions\ms-python.python-2023.4.1\pythonFiles\lib\python\debugpy\adapter/../..\debugpy\launcher 52925 -- C:\Users\HUST\Desktop\crnn_plate_recognition\train.py "
=> creating output\360CC\crnn\2023-12-14-17-17\checkpoints
=> creating output\360CC\crnn\2023-12-14-17-17\log

layer                                               name  gradient   parameters                shape           mu        sigma
    0                                   feature.0.weight      True         1200        [16, 3, 5, 5]     -0.00129       0.0672
    1                                     feature.0.bias      True           16                 [16]     -0.00503       0.0724
    2                                   feature.1.weight      True           16                 [16]            1            0
    3                                     feature.1.bias      True           16                 [16]            0            0
    4                                   feature.3.weight      True         2304       [16, 16, 3, 3]      0.00171       0.0482
    5                                     feature.3.bias      True           16                 [16]       0.0188       0.0519
    6                                   feature.4.weight      True           16                 [16]            1            0
    7                                     feature.4.bias      True           16                 [16]            0            0
    8                                   feature.6.weight      True         4608       [32, 16, 3, 3]     4.04e-05       0.0484
    9                                     feature.6.bias      True           32                 [32]     -0.00626       0.0452
   10                                   feature.7.weight      True           32                 [32]            1            0
   11                                     feature.7.bias      True           32                 [32]            0            0
   12                                   feature.9.weight      True         9216       [32, 32, 3, 3]     0.000206       0.0341
   13                                     feature.9.bias      True           32                 [32]      0.00341       0.0318
   14                                  feature.10.weight      True           32                 [32]            1            0
   15                                    feature.10.bias      True           32                 [32]            0            0
   16                                  feature.13.weight      True        18432       [64, 32, 3, 3]     6.27e-05        0.034
   17                                    feature.13.bias      True           64                 [64]       0.0084       0.0305
   18                                  feature.14.weight      True           64                 [64]            1            0
   19                                    feature.14.bias      True           64                 [64]            0            0
   20                                  feature.16.weight      True        36864       [64, 64, 3, 3]     6.22e-07       0.0241
   21                                    feature.16.bias      True           64                 [64]     -0.00337       0.0243
   22                                  feature.17.weight      True           64                 [64]            1            0
   23                                    feature.17.bias      True           64                 [64]            0            0
   24                                  feature.20.weight      True        55296       [96, 64, 3, 3]    -3.25e-05       0.0241
   25                                    feature.20.bias      True           96                 [96]      0.00238       0.0236
   26                                  feature.21.weight      True           96                 [96]            1            0
   27                                    feature.21.bias      True           96                 [96]            0            0
   28                                  feature.23.weight      True        82944       [96, 96, 3, 3]     6.25e-06       0.0196
   29                                    feature.23.bias      True           96                 [96]    -0.000941         0.02
   30                                  feature.24.weight      True           96                 [96]            1            0
   31                                    feature.24.bias      True           96                 [96]            0            0
   32                                  feature.27.weight      True       110592      [128, 96, 3, 3]    -2.34e-05       0.0196
   33                                    feature.27.bias      True          128                [128]    -0.000311       0.0204
   34                                  feature.28.weight      True          128                [128]            1            0
   35                                    feature.28.bias      True          128                [128]            0            0
   36                                  feature.30.weight      True       294912     [256, 128, 3, 3]     2.84e-05        0.017
   37                                    feature.30.bias      True          256                [256]      0.00124        0.016
   38                                  feature.31.weight      True          256                [256]            1            0
   39                                    feature.31.bias      True          256                [256]            0            0
   40                                      newCnn.weight      True        19968      [78, 256, 1, 1]     0.000316        0.036
   41                                        newCnn.bias      True           78                 [78]    -0.000652       0.0396
Model Summary: 42 layers, 638814 parameters, 638814 gradients

load 62863 images!
load 2014 images!
Epoch: [0][0/246]       Time 1200.660s (1200.660s)      Speed 0.2 samples/s     Data 11.834s (11.834s)  Loss 10.75087 (10.75087)
Epoch: [0][100/246]     Time 0.060s (16.706s)   Speed 4266.7 samples/s  Data 0.002s (0.236s)    Loss 0.80793 (2.61764)
Epoch: [0][200/246]     Time 0.063s (8.425s)    Speed 4063.5 samples/s  Data 0.002s (0.120s)    Loss 0.11421 (1.48051)---E--D--3-S-11-22-- => 苏ED3S12             , gt: 苏ED3S12
赣--EE--K--1--3-22-0-- => 赣EK1320             , gt: 赣EK1320
渝-AA---66-P--8--6-11- => 渝A6P861             , gt: 渝A6P861
新--LL--66-0-11-77-99- => 新L60179             , gt: 新L60179
鄂---N--66-MM-Y-11-22- => 鄂N6MY12             , gt: 鄂N6MY12
豫---B--D--8--2-11-11- => 豫BD8211             , gt: 渝BD8211
苏--E---1--1-LL-33-LL- => 苏E11L3L             , gt: 苏E11L3L
赣---C--55-55-T--6-22- => 赣C55T62             , gt: 赣C55T62
川-AA---88-8--Y--3-55- => 川A88Y35             , gt: 川A88Y35
渝--BB--NN-77-9-88-66- => 渝BN7986             , gt: 渝BN7986
1540
128000
Test loss: 0.1785, accuray: 0.7646
is best: True
best acc is: 0.7646474677259185

你可能感兴趣的:(车牌识别,深度学习,python)