Notes: Training a License Plate Detection Model with PaddleOCR on the CCPD Dataset

Contents

1-- Environment setup

2-- Dataset split

3-- Model training

4-- Inference test


1-- Environment setup

PaddleOCR installation reference

① Create the environment

conda create -n paddle_env python=3.8
conda activate paddle_env

② Install paddlepaddle
# Switch the CUDA version to 11.1 (adjust to your own setup)

sudo gedit ~/.bashrc
source ~/.bashrc

# Install paddlepaddle

python -m pip install paddlepaddle-gpu==2.3.0.post111 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html 

③ Install dependencies
# Install the PaddleOCR whl

pip install "paddleocr>=2.0.1"

# Layout analysis

pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl

2-- Dataset split

# Uses the challenge subset of the CCPD dataset
# 45003 images for the training set, 5000 for the validation set

Code for moving the dataset files

import shutil
import os

def move_files(old_path, new_path):
    # Create the target directory if it does not exist yet
    os.makedirs(new_path, exist_ok=True)
    # os.listdir returns bare file names (without the directory), in arbitrary order
    filelist = os.listdir(old_path)
    # Move the first 5000 files to new_path to form the validation set
    for file in filelist[:5000]:
        src = os.path.join(old_path, file)
        dst = os.path.join(new_path, file)
        shutil.move(src, dst)

if __name__ == '__main__':
    move_files(r"/civi/Chinese_license_plate_Note/detection/dataset/ccpd_challenge", r"/civi/Chinese_license_plate_Note/detection/dataset/test_dataset")
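
The script above moves whichever 5000 files os.listdir happens to return first, so the split can differ from machine to machine. As an optional alternative, here is a minimal sketch of a seeded random split. The function name split_dataset, the seed value and the demo file names are my own assumptions for illustration, not part of the original workflow.

```python
import os
import random
import shutil
import tempfile


def split_dataset(src_dir, dst_dir, num_val, seed=0):
    """Move `num_val` randomly chosen files from src_dir to dst_dir."""
    os.makedirs(dst_dir, exist_ok=True)
    # Sort first so the shuffle is reproducible regardless of listdir order
    files = sorted(os.listdir(src_dir))
    random.Random(seed).shuffle(files)
    chosen = files[:num_val]
    for name in chosen:
        shutil.move(os.path.join(src_dir, name), os.path.join(dst_dir, name))
    return chosen


# Small demonstration on throwaway directories (hypothetical file names)
with tempfile.TemporaryDirectory() as root:
    train_dir = os.path.join(root, "train_dataset")
    val_dir = os.path.join(root, "test_dataset")
    os.makedirs(train_dir)
    for i in range(10):
        open(os.path.join(train_dir, f"img_{i}.jpg"), "w").close()
    moved = split_dataset(train_dir, val_dir, num_val=3)
    remaining = len(os.listdir(train_dir))
    print(len(moved), "moved,", remaining, "left")
```

For the real dataset you would call it with the ccpd_challenge directory, the test_dataset directory and num_val=5000.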

# Create the label annotation file

Format reference

[Image 1: label format reference]
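
Since the reference image is not reproduced here, the following sketch shows the line layout that the scripts in this section produce: the image file name, a tab, then a JSON list of annotations with "transcription" and "points" keys. The helper names and the file name example.jpg are hypothetical. Using json.dumps instead of manual string concatenation is my own choice.

```python
import json


def make_label_line(file_name, transcription, points):
    """Build one label line: image name, a tab, then a JSON list."""
    annotation = [{"transcription": transcription, "points": points}]
    # ensure_ascii=False keeps the Chinese province character readable
    return file_name + "\t" + json.dumps(annotation, ensure_ascii=False) + "\n"


def parse_label_line(line):
    """Split a label line back into the image name and its annotations."""
    file_name, payload = line.rstrip("\n").split("\t", 1)
    return file_name, json.loads(payload)


line = make_label_line(
    "example.jpg",  # hypothetical image name
    "皖AY339S",
    [[154, 383], [363, 402], [386, 473], [177, 454]],
)
name, annotations = parse_label_line(line)
print(name, annotations[0]["transcription"], annotations[0]["points"])
```

Reading the line back with json.loads is a cheap way to confirm that a generated label file is well formed before starting a long training run.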

Code sample

import os

words_list = [
    "A", "B", "C", "D", "E",
    "F", "G", "H", "J", "K",
    "L", "M", "N", "P", "Q",
    "R", "S", "T", "U", "V",
    "W", "X", "Y", "Z", "0",
    "1", "2", "3", "4", "5",
    "6", "7", "8", "9"
]

con_list = [
    "皖", "沪", "津", "渝", "冀",
    "晋", "蒙", "辽", "吉", "黑",
    "苏", "浙", "京", "闽", "赣",
    "鲁", "豫", "鄂", "湘", "粤",
    "桂", "琼", "川", "贵", "云",
    "藏", "陕", "甘", "青", "宁",
    "新"
]

if __name__ == "__main__":
    dataset_dir = '/civi/Chinese_license_plate_Note/detection/dataset/'
    # Open the label file once (append mode) rather than once per image
    with open(dataset_dir + 'test_label.txt', 'a', encoding='UTF-8') as f:
        for item in os.listdir(dataset_dir + 'test_dataset/'):  # iterate over the images

            _, _, bbox, points, label, _, _ = item.split('-')  # split the file name into its fields

            # Split the four vertices and turn each "x&y" into [x, y],
            # keeping the order in which CCPD stores them
            points = [[int(v) for v in p.split('&')] for p in points.split('_')]
            # print(points)

            # Decode the plate: the first index is the province, the rest are letters/digits
            label = label.split('_')
            con = con_list[int(label[0])]
            words = [words_list[int(i)] for i in label[1:]]
            label = '"' + con + ''.join(words) + '"'

            # One line per image: file name, a tab, then the annotation list
            List = '[{"transcription": ' + label + ', "points": ' + str(points) + '}]'
            f.write(item + '\t' + List + '\n')

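I made a mistake in the code above: CCPD lists the four vertices clockwise starting from the bottom-right corner, whereas the annotation file for OCR detection expects them clockwise starting from the top-left corner. The code is therefore changed to: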

import os

words_list = [
    "A", "B", "C", "D", "E",
    "F", "G", "H", "J", "K",
    "L", "M", "N", "P", "Q",
    "R", "S", "T", "U", "V",
    "W", "X", "Y", "Z", "0",
    "1", "2", "3", "4", "5",
    "6", "7", "8", "9"
]

con_list = [
    "皖", "沪", "津", "渝", "冀",
    "晋", "蒙", "辽", "吉", "黑",
    "苏", "浙", "京", "闽", "赣",
    "鲁", "豫", "鄂", "湘", "粤",
    "桂", "琼", "川", "贵", "云",
    "藏", "陕", "甘", "青", "宁",
    "新"
]

if __name__ == "__main__":
    dataset_dir = '/civi/Chinese_license_plate_Note/detection/dataset/'
    # Open the label file once (append mode) rather than once per image
    with open(dataset_dir + 'test_label.txt', 'a', encoding='UTF-8') as f:
        for item in os.listdir(dataset_dir + 'test_dataset/'):  # iterate over the images

            _, _, bbox, points, label, _, _ = item.split('-')  # split the file name into its fields

            # Split the four vertices and turn each "x&y" into [x, y]
            tmp = [[int(v) for v in p.split('&')] for p in points.split('_')]
            # CCPD order: bottom-right, bottom-left, top-left, top-right.
            # Reorder to top-left, top-right, bottom-right, bottom-left.
            points = [tmp[2], tmp[3], tmp[0], tmp[1]]

            # Decode the plate: the first index is the province, the rest are letters/digits
            label = label.split('_')
            con = con_list[int(label[0])]
            words = [words_list[int(i)] for i in label[1:]]
            label = '"' + con + ''.join(words) + '"'

            # One line per image: file name, a tab, then the annotation list
            List = '[{"transcription": ' + label + ', "points": ' + str(points) + '}]'
            f.write(item + '\t' + List + '\n')
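
To check the reordering on a concrete case, here is a small self-contained sketch. It parses one file name in the CCPD naming format and verifies that the resulting points run clockwise from the top-left corner. The function names are hypothetical, the compact lookup tables use 藏 at index 25, and taking the top-left corner as the vertex with the smallest x + y is a heuristic I am assuming holds for mildly tilted plates.

```python
PROVINCES = list("皖沪津渝冀晋蒙辽吉黑苏浙京闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新")
CHARS = list("ABCDEFGHJKLMNPQRSTUVWXYZ0123456789")


def parse_ccpd_name(file_name, provinces, chars):
    """Return (plate, points) with points ordered clockwise from top-left."""
    fields = file_name.split("-")
    vertices = [[int(v) for v in p.split("&")] for p in fields[3].split("_")]
    # CCPD order: bottom-right, bottom-left, top-left, top-right
    points = [vertices[2], vertices[3], vertices[0], vertices[1]]
    codes = [int(c) for c in fields[4].split("_")]
    plate = provinces[codes[0]] + "".join(chars[c] for c in codes[1:])
    return plate, points


def is_clockwise_from_top_left(points):
    """Check the ordering in image coordinates (y grows downwards)."""
    # Shoelace sum: positive means clockwise on screen when y points down
    area2 = sum(
        x1 * y2 - x2 * y1
        for (x1, y1), (x2, y2) in zip(points, points[1:] + points[:1])
    )
    # Heuristic: the top-left corner is the vertex with the smallest x + y
    top_left = min(points, key=lambda p: p[0] + p[1])
    return area2 > 0 and points[0] == top_left


# Illustrative file name in the CCPD naming format
sample = ("025-95_113-154&383_386&473-"
          "386&473_177&454_154&383_363&402-0_0_22_27_27_33_16-37-15.jpg")
plate, points = parse_ccpd_name(sample, PROVINCES, CHARS)
print(plate, points, is_clockwise_from_top_left(points))
```

Running the same check on the vertices in their original CCPD order returns False, which is exactly the mistake described above.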

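3-- Model training

Training reference

① Download the pretrained model (DB-Net)

 Download link

② Set up the config file

Example (pay attention to the commented lines and adjust them to your setup):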

Global:
  use_gpu: true # whether to use the GPU
  epoch_num: 200  # number of epochs
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: /civi/Chinese_license_plate_Note/detection/Models_Well_trained/200epochs # directory where models are saved
  save_epoch_step: 10 # interval (in epochs) between saved checkpoints
  # evaluation is run every 2000 iterations after the 3000th iteration
  eval_batch_step: [3000, 2000]
  cal_metric_during_train: False
  pretrained_model: /civi/Chinese_license_plate_Note/detection/pretrain/ch_ppocr_server_v2.0_det_train/best_accuracy  # path to the pretrained model
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: /civi/Chinese_license_plate_Note/detection/test_img/test6.22.png # test image
  save_res_path: /civi/Chinese_license_plate_Note/detection/Models_Well_trained/100epochs/det_db/predicts_db.txt

Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: ResNet
    layers: 18
    disable_se: True
  Neck:
    name: DBFPN
    out_channels: 256
  Head:
    name: DBHead
    k: 50

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: /civi/Chinese_license_plate_Note/detection/dataset/train_dataset/ # training set images
    label_file_list:
      - /civi/Chinese_license_plate_Note/detection/dataset/train_label.txt # training set labels
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8 # batch size per GPU
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /civi/Chinese_license_plate_Note/detection/dataset/test_dataset/ # validation set images
    label_file_list:
      - /civi/Chinese_license_plate_Note/detection/dataset/test_label.txt # validation set labels
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
#           image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2
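
To get a feel for what these numbers mean in practice, the sketch below estimates the iterations per epoch for 45003 training images with batch_size_per_card: 8 on two GPUs. It also lists the first iterations at which eval_batch_step: [3000, 2000] would trigger an evaluation. Both formulas are my own approximations: I am assuming the last partial batch is kept (drop_last: False) and that evaluation fires when step > start and (step - start) % interval == 0. The function names are hypothetical.

```python
import math


def iters_per_epoch(num_images, batch_size_per_card, num_cards):
    """Iterations per epoch when the last partial batch is kept."""
    return math.ceil(num_images / (batch_size_per_card * num_cards))


def eval_steps(start, interval, total_iters):
    """Iterations at which evaluation would run under the assumed rule."""
    # Assumed rule: step > start and (step - start) % interval == 0
    return [s for s in range(1, total_iters + 1)
            if s > start and (s - start) % interval == 0]


per_epoch = iters_per_epoch(45003, 8, 2)
total = per_epoch * 200  # epoch_num: 200
steps = eval_steps(3000, 2000, total)
print(per_epoch, total, steps[:3])
```

Under these assumptions an epoch is roughly 2800 iterations, so the model is evaluated a little more often than once per epoch.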

③ Start training

# Here I train on multiple GPUs, using GPU ids 1 and 2

python3 -m paddle.distributed.launch --gpus '1,2' tools/train.py \
-c /civi/Chinese_license_plate_Note/detection/test.yml

4-- Inference test

# Export to an inference model

python3 tools/export_model.py -c /civi/Chinese_license_plate_Note/detection/test_0705.yml \
-o Global.pretrained_model="/civi/Chinese_license_plate_Note/detection/Models_Well_trained/100epochs7.05/best_accuracy" \
Global.save_inference_dir="/civi/Chinese_license_plate_Note/detection/Models_Well_trained/100epoches_tuili/"

# Run inference

python3 tools/infer/predict_det.py --det_algorithm="DB" \
--det_model_dir="/civi/Chinese_license_plate_Note/detection/Models_Well_trained/100epoches_tuili/" \
--image_dir="/civi/Chinese_license_plate_Note/detection/test_img/test2.jpeg" \
--use_gpu=True

[Image 2: inference result]

 ## To be continued
