Paddle OCR文字检测(三)(微调)

本篇基于PaddleOCR开源套件,以PP-OCRv3检测和识别模型为基础,针对液晶屏读数识别场景进行优化。

Aistudio项目链接:OCR液晶屏读数识别

1.PP-OCRv3检测算法介绍

PP-OCRv3检测模型是对PP-OCRv2中的CML(Collaborative Mutual Learning) 协同互学习文本检测蒸馏策略进行了升级。如下图所示,CML的核心思想结合了①传统的Teacher指导Student的标准蒸馏与 ②Students网络之间的DML互学习,可以让Students网络互学习的同时,Teacher网络予以指导。PP-OCRv3分别针对教师模型和学生模型进行进一步效果优化。其中,在对教师模型优化时,提出了大感受野的PAN结构LK-PAN和引入了DML(Deep Mutual Learning)蒸馏策略;在对学生模型优化时,提出了残差注意力机制的FPN结构RSE-FPN。

Paddle OCR文字检测(三)(微调)_第1张图片

详细优化策略描述请参考PP-OCRv3优化策略

2.数据准备

aie2023数据来源于实际项目中各种计量设备的数显屏,以及在网上搜集的一些其他数显屏,包含训练集348张,验证集116,测试集115张。

①将数据集aie2023放到PaddleOCR下文件夹train_data中

# 随机查看文字检测数据集图片
from PIL import Image  
import matplotlib.pyplot as plt
import numpy as np
import os

train = './train_data/aie2023/det/test'
# 从指定目录中选取一张图片
def get_one_image(train):
    plt.figure()
    files = os.listdir(train)
    n = len(files)
    ind = np.random.randint(0,n)
    img_dir = os.path.join(train,files[ind])  
    image = Image.open(img_dir)  
    plt.imshow(image)
    plt.show()
    image = image.resize([208, 208])  

get_one_image(train)  
3. 模型训练
模型微调基本步骤:
(创建模型配置文件,下载预训练模型)→垂直数据集评估预训练模型→垂直数据集微调预训练模型→评估微调后的模型
①创建训练配置文件configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml_aie2023.yml

(本人在configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml上修改重命名为configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml_aie2023.yml)

configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml_aie2023.yml的内容如下:

Global:
  debug: false
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/ch_PP-OCR_v3_det_aie2023/
  save_epoch_step: 100
  eval_batch_step:
  - 0
  - 400
  cal_metric_during_train: false
  pretrained_model: null
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: train_data/aie2023/det/test/65_25_0197-01S-S200.bmp
  save_res_path: ./checkpoints/det_db/predicts_db_aie2023.txt
  distributed: true

Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    Student:
      pretrained:
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained:
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Teacher:
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: ResNet_vd
        in_channels: 3
        layers: 50
      Neck:
        name: LKPAN
        out_channels: 256
      Head:
        name: DBHead
        kernel_list: [7,2,2]
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDMLLoss:
      model_name_pairs:
      - ["Student", "Student2"]
      maps_name: "thrink_maps"
      weight: 1.0
      model_name_pairs: ["Student", "Student2"]
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Student2"]
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 5.0e-05

PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/aie2023/det/
    label_file_list:
      - ./train_data/aie2023/det/train.txt
    ratio_list: [1.0]
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - CopyPaste:
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 960
        - 960
        max_tries: 50
        keep_ratio: true
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 8
    num_workers: 0

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/aie2023/det/
    label_file_list:
      - ./train_data/aie2023/det/val.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 0
②预训练模型直接评估

下载我们需要的PP-OCRv3检测预训练模型,更多选择请自行选择其他的文字检测模型

#使用该指令下载需要的预训练模型
wget-P ./pretrained_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
# 解压预训练模型文件
tar-xf ./pretrained_models/ch_PP-OCRv3_det_distill_train.tar-Cpretrained_models
Paddle OCR文字检测(三)(微调)_第2张图片

在训练之前,我们可以直接使用下面命令来评估预训练模型的效果:

# 评估预训练模型
python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy"
例如:
python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml_aie2023.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy"

PP-OCRv3检测预训练模型在数据集aie2023上的评估结果如下:

Paddle OCR文字检测(三)(微调)_第3张图片
③ 预训练模型直接finetune

a.修改配置文件:我们使用configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml_aie2023.yml,主要修改训练轮数和学习率参相关参数,设置预训练模型路径,设置数据集路径。 另外,batch_size可根据自己机器显存大小进行调整。 具体修改如下几个地方:

epoch:100
save_epoch_step:10
eval_batch_step:[0, 50]
save_model_dir: ./output/ch_PP-OCR_v3_det/
pretrained_model: ./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy
learning_rate: 0.00025
num_workers: 0 # 如果单卡训练,建议将Train和Eval的loader部分的num_workers设置为0,否则会出现`/dev/shm insufficient`的报错

b.开始训练:使用我们上面修改的配置文件configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml_aie2023.yml,训练命令如下:

# 开始训练模型
python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml_aie2023.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy

评估训练好的模型

# 评估训练好的模型
pythontools/eval.py-cconfigs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml_aie2023.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_aie2023/best_accuracy"
详情请参考:

基于PP-OCRv3的液晶屏读数检测

你可能感兴趣的:(OCR,python,计算机视觉,目标检测,paddle)