PaddleOCR文字检测模型训练

PaddleOCR文字检测模型训练

本文档主要介绍PaddleOCR中文字检测模型的训练、评估及测试。

训练环境

  • CentOS 7
  • python3.7
  • paddlepaddle-gpu 2.0.0rc0

数据准备

自标注数据

  1. 将所有训练图片放在icdar_c4_train_imgs文件夹下,所有测试图片放在ch4_test_images目录下,以免多次标注产生多次修改代码问题。
  2. 将所有图片格式统一改成.jpg或者统一改成.png,目标为图片保持统一格式。
  3. 将标注生成的XML文件放在一个目录中,训练数据XML和测试数据XML分开。

然后执行下面的代码,把其中的xml_dir、train_file_name、train_file_label_name替换成自己的参数,执行完后会在train_file_name目录下生成一个train_file_label_name的txt文件,即为label文件。label文件与图片一起组成paddleOCR数据集。

#!/usr/local/env python3
# -*- coding: utf-8 -*-
"""
Auther: BurningSilence
date: 2020/11/16 下午5:13

DESC:
"""
import os
import xml.etree.ElementTree as et
import math


def edit_xml(xml_file, train_file_name):
    """
    VOC转换为PaddleOCR label
    :param xml_file:xml文件的路径
    :return: xml文件对应的label
    """
    x0 = y0 = x1 = y1 = x2 = y2 = x3 = y3 = 0
    tree = et.parse(xml_file)
    root = tree.getroot()
    root.attrib = None
    img_name = root.find("path").text.split("/")[-1]
    objs = root.findall('object')
    transcription_arr = []
    for obj in objs:
        dict_bak = {}
        points_arr = []
        obj_type = obj.find('type')
        name = obj.find('name')
        type = obj_type.text
        if type == 'bndbox':
            obj_bnd = obj.find('bndbox')
            xmin = int(float(obj_bnd.find('xmin').text))
            ymin = int(float(obj_bnd.find('ymin').text))
            xmax = int(float(obj_bnd.find('xmax').text))
            ymax = int(float(obj_bnd.find('ymax').text))

            x0, y0 = xmin, ymin
            x1, y1 = xmax, ymin
            x2, y2 = xmin, ymax
            x3, y3 = xmax, ymax
        elif type == 'robndbox':
            obj_bnd = obj.find('robndbox')
            cx = float(obj_bnd.find('cx').text)
            cy = float(obj_bnd.find('cy').text)
            w = float(obj_bnd.find('w').text)
            h = float(obj_bnd.find('h').text)
            angle = float(obj_bnd.find('angle').text)

            x0, y0 = rotate_point(cx, cy, cx - w / 2, cy - h / 2, -angle)
            x1, y1 = rotate_point(cx, cy, cx + w / 2, cy - h / 2, -angle)
            x2, y2 = rotate_point(cx, cy, cx + w / 2, cy + h / 2, -angle)
            x3, y3 = rotate_point(cx, cy, cx - w / 2, cy + h / 2, -angle)

        points_arr.append([x0, y0])
        points_arr.append([x1, y1])
        points_arr.append([x2, y2])
        points_arr.append([x3, y3])
        dict_bak["transcription"] = name.text
        dict_bak["points"] = points_arr
        transcription_arr.append(dict_bak)
    _img_label = train_file_name + "/" + img_name + " " + str(transcription_arr)
    print(_img_label)
    return _img_label


# 旋转后的四点坐标
def rotate_point(cx, cy, x_no_angle, y_no_angle, angle):
    x_off = x_no_angle - cx
    y_off = y_no_angle - cy
    cos_angle = math.cos(angle)
    sin_angle = math.sin(angle)
    cx_bak = cos_angle * x_off + sin_angle * y_off
    cy_bak = - sin_angle * x_off + cos_angle * y_off
    return int(cx + cx_bak), int(cy + cy_bak)


if __name__ == '__main__':
    # XML文档路径
    xml_dir = "/Users/andy/workspace/project/16-PaddleOCR/XML/"
    # 训练数据所在文件夹名称
    train_file_name = "icdar_c4_train_imgs"
    # 生成的label文件名
    train_file_label_name = "train_icdar2015_label.txt"
    label = ""
    for file in os.listdir(xml_dir):
        if file.endswith("xml"):
            img_label = edit_xml(os.path.join(xml_dir, file), train_file_name)
            label += img_label + "\n"
    with open(xml_dir + train_file_label_name, 'w') as f:
        f.write(label)

最终数据集目录如下

/PaddleOCR/train_data/icdar2015/text_localization/
  └─ icdar_c4_train_imgs/         icdar数据集的训练数据
  └─ ch4_test_images/             icdar数据集的测试数据
  └─ train_icdar2015_label.txt    icdar数据集的训练标注
  └─ test_icdar2015_label.txt     icdar数据集的测试标注

ICDAR2019-LSVT数据集

icdar2019数据集可以从官网下载到,首次下载需注册。

数据简介: 共45w中文街景图像,包含5w(2w测试+3w训练)全标注数据(文本坐标+文本内容),40w弱标注数据(仅文本内容),如下图所示:

PaddleOCR文字检测模型训练_第1张图片

(a) 全标注数据

PaddleOCR文字检测模型训练_第2张图片

(b) 弱标注数据

由于弱标注数据没有文本的坐标,本文档暂时只用全标注数据,将其分为训练集和测试集(可自行分配比例),本文档分配比例8:2。

将下载的train_full_images_0.tar.gz、train_full_images_1.tar.gz解压后,分配比例放在icdar2019_train_imgs和icdar2019_test_imgs目录下,然后按下面的代码改写train_full_labels.json

#!/usr/local/env python3
# -*- coding: utf-8 -*-
"""
Auther: BurningSilence
date: 2020/11/20 上午9:49

DESC:将icdar2019数据集label转换为PaddleOCR label
"""
import json

train_imgs_path = "icdar2019_train_imgs/"

old_label_json = "train_full_labels.json"
new_label_txt = "train_icdar2019_label.txt"
with open(old_label_json, 'r', encoding='utf-8') as fr:
    with open(new_label_txt, 'w') as fw:
        for key, value in json.load(fr).items():
            fw.writelines(train_imgs_path + key + ".jpg\t" + str(value) + "\n")

并且把ppocr/data/det/db_process.py以及tools/eval_utils/eval_det_utils.py文件中的

label = json.loads(substr[1])
# 修改为
label = eval(substr[1])

因为json.loads不能识别单引号。
最终数据集目录如下

/PaddleOCR/train_data/icdar2015/text_localization/
  └─ icdar2019_train_imgs/        icdar2019数据集的训练数据
  └─ icdar2019_test_imgs/         icdar2019数据集的测试数据
  └─ train_icdar2019_label.txt    icdar2019数据集的训练标注
  └─ test_icdar2019_label.txt     icdar2019数据集的测试标注

快速启动训练

下载预训练模型

首先下载模型backbone的pretrain model,本次以MobileNetV3模型为例,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet_vd系列, 您可以根据需求使用PaddleClas中的模型更换backbone(骨架网络)。

cd PaddleOCR/
# 下载MobileNetV3的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
# 或,下载ResNet18_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_vd_pretrained.tar
# 或,下载ResNet50_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar

# 解压预训练模型文件,以MobileNetV3为例
tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models/

# 注:正确解压backbone预训练权重文件后,文件夹下包含众多以网络层命名的权重文件,格式如下:
./pretrain_models/MobileNetV3_large_x0_5_pretrained/
  └─ conv_last_bn_mean
  └─ conv_last_bn_offset
  └─ conv_last_bn_scale
  └─ conv_last_bn_variance
  └─ ......

启动训练

如果您安装的是cpu版本,请将配置文件(./configs/det/det_mv3_db_icdar2019_v1.1.yml)中的 use_gpu 字段修改为false。本文档使用的是GPU进行训练(第1和第2块GPU),所以需要设置环境变量(在命令行中进行训练时设置):

export CUDA_VISIBLE_DEVICES=1,2

如果使用Pycharm训练的则需在Pycharm设置,具体设置方法为在菜单Edit->Run configurations->Environment variables添加变量的name和value即可。

命令行训练命令:

python3 tools/train.py -c configs/det/det_mv3_db_icdar2019_v1.1.yml \
     -o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/ \
     2>&1 | tee train_det_icdar2019.log

PS:

  1. det_mv3_db_icdar2019_v1.1.yml 可根据 det_mv3_db_v1.1.yml改写。

断点训练

如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:

python3 tools/train.py -c configs/det/det_mv3_db_icdar2019_v1.1.yml -o Global.checkpoints=output/det_db_icdar2019/best_accuracy

注意:Global.checkpoints的优先级高于Global.pretrain_weights的优先级,即同时指定两个参数时,优先加载Global.checkpoints指定的模型,如果Global.checkpoints指定的模型路径有误,会加载Global.pretrain_weights指定的模型。

指标评估

PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。

运行如下代码,根据配置文件det_db_mv3_v1.1.yml中save_res_path指定的测试集检测结果文件,计算评估指标。

评估时设置后处理参数box_thresh=0.6,unclip_ratio=1.5,使用不同数据集、不同模型训练,可调整这两个参数进行优化。

python3 tools/eval.py -c configs/det/det_mv3_db_icdar2019_v1.1.yml  -o Global.checkpoints="./output/det_db_icdar2019/iter_epoch_800" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5

上面800为训练的epoch数。

测试检测效果

测试单张图像的检测效果

python3 tools/infer_det.py -c configs/det/det_mv3_db_icdar2019_v1.1.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db_icdar2019/iter_epoch_800"

测试DB模型时,调整后处理阈值,

python3 tools/infer_det.py -c configs/det/det_mv3_db_icdar2019_v1.1.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db_icdar2019/iter_epoch_800" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5

测试文件夹下所有图像的检测效果

python3 tools/infer_det.py -c configs/det/det_mv3_db_icdar2019_v1.1.yml -o Global.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db_icdar2019/iter_epoch_800"

PS

如果训练中遇到这种问题

json.decoder.JSONDecodeError: Expecting property name enclosed in double quotes: line 1 column 10 (char 9)

原因可能是数据集格式问题,其实主要是PaddleOCR代码中用的是json可以参考:python json常用方法总结,loads与dumps区别,load与dump区别
参考:
文字检测

你可能感兴趣的:(PaddleOCR,paddlepaddle,ocr)