PaddleOCR学习(二)PaddleOCR检测模型训练

这一部分主要介绍,如何使用自己的数据库去训练PaddleOCR的文本检测模型。

官方教程https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md

一、准备训练数据

首先你需要有自己的数据,如果没有自己的数据,推荐使用ICDAR2015的数据库,上网搜即可找到,内含1000个训练样本和500个测试样本,包括图片与标准数据(txt格式)。

如何标注自己的数据大家可以自行去网上搜索一下,PaddleOCR自带标注工具PPOCRLabel:https://github.com/PaddlePaddle/PaddleOCR/tree/develop/PPOCRLabel

不过因为我不是用PPOCRLabel进行的标注,而是采用了另一种更麻烦的方法进行标注,所以这里就不班门弄斧了,如果使用PPOCRLabel的过程中出了问题,也可以考虑采用我的方法:

(1)首先由于我的数据中涉及到了倾斜文本(弯曲文本我还没有了解过有没有什么特别好的检测模型,目前主流的检测模型可能也只到倾斜文本),所以我使用的是roLabelImg工具进行的标注;

(2)使用rolabelImg工具标注图片获得倾斜文本框,输出xml文件;

(3)将xml文件转换为txt文件,具体转换算法我放在本文最后xmltotxt.py:

需要注意的是,txt中的内容格式应该是: x 1 , y 1 , x 2 , y 2 , x 3 , y 3 , x 4 , y 4 , t e x t x_1,y_1,x_2,y_2,x_3,y_3,x_4,y_4,text x1,y1,x2,y2,x3,y3,x4,y4,text。对于roLablelImg标注的数据,角点坐标都保留两位小数,但是PaddleOCR中是按整数进行的计算,所以后面需要一点细微的修改。

此时获得,一张图片对应一个标注txt文件中的内容应该像以下内容:
PaddleOCR学习(二)PaddleOCR检测模型训练_第1张图片
(4)现在获得的应该是一个包含所有图片的文件夹与一个包含相同数量与图片同名txt文件的文件夹,接下来需要将该文件夹先分成训练用样本和测试用样本,为了后续方便,先新建以下结构的文件夹:

PaddleOCR学习(二)PaddleOCR检测模型训练_第2张图片
DatasetRes是我自己的数据集的名字,将标注好的数据按一定比例分别放进train_imgs和test_imgs中(具体的比例不好说,我也是新手,我觉得可以参考ICDAR的比例,训练:测试=2:1)。

然后,打开train_data/gen_label.py,修改其中的模式、图片路径、标注路径、输出结果路径:

gen_label的效果是,将所有标注txt,总合成一个总的txt文件,记得分别对测试数据和训练数据运行gen_label,获得两个label.txt文件。

切记,输出完之后,尽量不要修改文件夹或者txt文件的名称。

parser.add_argument(
        '--mode',
        type=str,
        default="det",   # 模式
        help='Generate rec_label or det_label, can be set rec or det')
    parser.add_argument(
        '--root_path',
        type=str,
        default="DatasetRes/test_imgs/",   # 图片
        help='The root directory of images.Only takes effect when mode=det ')
    parser.add_argument(
        '--input_path',
        type=str,
        default="DatasetRes/test_txts/",   # 标注
        help='Input_label or input path to be converted')
    parser.add_argument(
        '--output_label',
        type=str,
        default="DatasetRes/test_label.txt",  # 输出结果
        help='Output file name')

另外,gen_label.py中还有两个可能会坑人的地方,都在gen_det_label()函数中,一个是paddleocr对坐标的读取是int类型,如果使用roLabelImg标注,一般获得的是浮点类型的;另一点是gen_det_label()函数在读取文件名时,会自动把文件名的前三位忽视掉(不知道为什么,可能和不同方法获得的标注结果有关,总之会引起错误)。我把修改过的代码放在下面了。

def gen_det_label(root_path, input_dir, out_label):
    with open(out_label, 'w') as out_file:
        for label_file in os.listdir(input_dir):
            img_path = root_path + label_file[:-4] + ".jpg"      # 原先是label_file[3:-4]
            label = []
            with open(os.path.join(input_dir, label_file), 'r') as f:
                for line in f.readlines():
                    tmp = line.strip("\n\r").replace("\xef\xbb\xbf",
                                                     "").split(',')
                    points = tmp[:8]
                    s = []
                    for i in range(0, len(points), 2):
                        b = points[i:i + 2]
                        b = [int(float(t)) for t in b]     # 原来是b=[int(t) for t in b],无法读取小数
                        s.append(b)
                    result = {"transcription": tmp[8], "points": s}
                    label.append(result)

            out_file.write(img_path + '\t' + json.dumps(
                label, ensure_ascii=False) + '\n')

如此,就把paddleocr检测模型训练需要的数据集准备好了。总的label.txt文件的内容大致像以下这样:
PaddleOCR学习(二)PaddleOCR检测模型训练_第3张图片

二、使用自己的数据集训练检测模型

终于把数据集准备好了,接下来就可以准备开始训练模型了,训练模型用到的是tools/train.py文件,不过没什么需要在这里面修改的。

首先,官方提供了三个backbone预训练模型,分别是MobileNetV3,ResNet8_vd,ResNet50_vd
https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_vd_pretrained.tar
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar

非常好懂,就是ResNet50_vd非常非常大,没有四块以上GPU建议就不要尝试了。

新建pretrain_models/detect_pretrain_models文件夹,然后将下载的预训练模型解压到detect_pretrain_models下。
在这里插入图片描述
如果你去看教程,他会告诉你运行以下命令,然后你就会一脸懵逼发现什么都没有发生,所以我觉得还是需要再详细解释一下。

python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml \
     -o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/ \
     2>&1 | tee train_det.log

实际最后运行的指令应该像这样即可,记得在cmd或者anaconda prompt中cd到paddleocr-develop目录下执行:

python tools/train.py -c configs/det/det_r18_vd_db_v1.1.yml 2>&1 | tee train_det.log

重点,在运行该指令前,打开configs/det/det_r18_vd_db_v1.1.yml进行修改。

# det_r18_vd_db_v1.1.yml

Global:
  algorithm: DB     # 使用的文本检测算法,这里用的是DB,我后来用的east,我将r18对应east的yml文件放在本文最后
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/det_r18_vd_db/     # 训练好的模型输出位置
  save_epoch_step: 200
  eval_batch_step: [3000, 2000]
  train_batch_size_per_card: 8
  test_batch_size_per_card: 1
  image_shape: [3, 640, 640]
  reader_yml: ./configs/det/det_db_icdar15_reader.yml       # 记住这个文件,接下来就要改它
  pretrain_weights: ./pretrain_models/detect_pretrain_models/ResNet18_vd_pretrained/  # 预训练模型的保存路径
  save_res_path: ./output/det_r18_vd_db/predicts_db.txt     # 预测结果文件的保存路径
  checkpoints:
  save_inference_dir:
  infer_img:
# det_db_icdar15_reader.yml

TrainReader:
  reader_function: ppocr.data.det.dataset_traversal,TrainReader
  process_function: ppocr.data.det.east_process,EASTProcessTrain
  num_workers: 4 # 量力而行,看自己电脑配置
  img_set_dir: ./train_data/ # 记得只要写这么长就行了,label.txt文件中,图片文件名包含了DatasetRes/train_imgs/xxx.jpg
  label_file_path: ./train_data/DatasetReal/train_label.txt  # 刚才gen_label保存的文件路径
  background_ratio: 0.125
  min_crop_side_ratio: 0.1
  min_text_size: 10

EvalReader:
  reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
  process_function: ppocr.data.det.east_process,EASTProcessTest
  img_set_dir: ./train_data/ # 同理
  label_file_path: ./train_data/DatasetReal/test_label.txt  # 同理
  
TestReader:
  reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
  process_function: ppocr.data.det.east_process,EASTProcessTest
  img_set_dir: ./train_data/   # 同理
  label_file_path: ./train_data/DatasetReal/test_label.txt   # 同理
  do_eval: True

好了,都改好了,可以执行刚才的命令了:

python tools/train.py -c configs/det/det_r18_vd_db_v1.1.yml 2>&1 | tee train_det.log

训练时会将训练过程打印到train_det.log文件。
PaddleOCR学习(二)PaddleOCR检测模型训练_第4张图片

三、整理、评估训练结果

模型训练完之后,到det_r18_vd_db_v1.1.yml文件中的save_model_dir: ./output/det_r18_vd_db/位置去找训练结果,像这样:
PaddleOCR学习(二)PaddleOCR检测模型训练_第5张图片
具体每多少epoch输出一次可以在yml文件中设置,不多赘述。

接下来需要将模型转换为可部署文件,在paddleocr-develop目录下运行指令:

python tools/export_model.py -c configs/det/det_r18_vd_db_v1.1.yml 
                              -o Global.checkpoints="./output/det_r18_vd_db/best_accuracy" 
                                 Global.save_inference_dir="./output/det_r18_vd_db/export_model"

记得根据自己的保存路径进行修改。./output/my_det_r18_vd_db/export_model中应该有两个文件:model和params。

如果训练程序中途断了,希望加载训练中断的模型继续训练,可以通过如下指令:

python tools/train.py -c configs/det/det_r18_vd_db_v1.1.yml 
                      -o Global.checkpoints="./output/det_r18_vd_db/best_accuracy"

好了,现在有了模型,如何评估模型的有效性可以自己去搜索学习一下,对于目标检测类算法,需要计算Precision、Recall、Hmean,运行以下代码即可:

python tools/eval.py -c configs/det/det_r18_vd_db_v1.1.yml 
                     -o Global.checkpoints="./output/det_r18_vd_db/best_accuracy"
                     PostProcess.box_thresh=0.6 
                     PostProcess.unclip_ratio=1.5

即可输出该模型的Precision、Recall、Hmean。

这里需要注意,上述指令是针对DB算法,如果你用的不是DB算法,而是EAST算法,指令需要有所不同,主要是在PostProcess中,EAST和DB的PostProcess的参数不同,所以进行评估时也需要输入不同的参数。如果是EAST算法,指令为:

python tools/eval.py -c configs/det/det_r18_east.yml 
                     -o Global.checkpoints="./output/det_east/best_accuracy"  # 自行注意文件夹的不同
                     PostProcess.score_thresh=0.8 
                     PostProcess.cover_thresh=0.1
                     PostProcess.nms_thresh=0.2

最后是用训练好的模型去测试自己的图片看效果,在PaddleOCR学习(一)PaddleOCR安装与测试中我已经介绍过如何调用模型进行图片检测,只要将其中的det_model_dir的默认路径改到./output/det_r18_vd_db/export_model/即可。

不过其实,如果不输出成可部署文件,也可以直接进行图片测试,运行以下指令:

python tools/infer_det.py -c configs/det/det_r18_vd_db_v1.1.yml 
                          -o Global.infer_img="./doc/imgs_en/img_10.jpg" 
                             Global.checkpoints="./output/det_east/best_accuracy"

或者一次性测试一整个文件夹:

python tools/infer_det.py -c configs/det/det_r18_vd_db_v1.1.yml  
                          -o Global.infer_img="./doc/imgs_en/" 
                             Global.checkpoints="./output/det_east/best_accuracy"

还可以在测试过程中调整后处理阈值

python tools/infer_det.py -c configs/det/det_r18_vd_db_v1.1.yml 
                          -o Global.infer_img="./doc/imgs_en/img_10.jpg" 
                             Global.checkpoints="./output/det_east/best_accuracy"
                          PostProcess.box_thresh=0.6 
                          PostProcess.unclip_ratio=1.5

OK,至此检测模型训练完毕,至于如何调参获取更好的训练结果,我也想知道(–_--)

附件

# xmltotxt.py

# coding=utf-8

import os
import xml.dom.minidom
import cv2 as cv
import math


def xml_to_txt(indir, outdir):
    os.chdir(indir)
    xmls = os.listdir('.')
    for i, file in enumerate(xmls):
        file_save = file.split('.')[0] + '.txt'
        file_txt = os.path.join(outdir, file_save)
        f_w = open(file_txt, 'w')
        # actual parsing
        DOMTree = xml.dom.minidom.parse(file)
        annotation = DOMTree.documentElement
        filename = annotation.getElementsByTagName("path")[0]
        imgname = filename.childNodes[0].data
        img_temp = imgname.split('\\')[-1]
        img_temp = os.path.join(image_dir, img_temp)
        image = cv.imread(imgname)
#        cv.imwrite(img_temp, image)
        objects = annotation.getElementsByTagName("object")
        print(file)
        for object in objects:
            bbox = object.getElementsByTagName("robndbox")[0]
            cx = bbox.getElementsByTagName("cx")[0]
            x = float(cx.childNodes[0].data)
            print(x)
            cy = bbox.getElementsByTagName("cy")[0]
            y = float(cy.childNodes[0].data)
            print(y)
            cw = bbox.getElementsByTagName("w")[0]
            w = float(cw.childNodes[0].data)
            print(w)
            ch = bbox.getElementsByTagName("h")[0]
            h = float(ch.childNodes[0].data)
            print(h)
            cangel = bbox.getElementsByTagName("angle")[0]
            angle = float(cangel.childNodes[0].data)
            print(angle)
            cname = object.getElementsByTagName("name")[0]
            name = cname.childNodes[0].data
            print(name)
            x1, y1 = rotatePoint(x, y, x - w / 2, y - h / 2, -angle)
            x2, y2 = rotatePoint(x, y, x + w / 2, y - h / 2, -angle)
            x3, y3 = rotatePoint(x, y, x + w / 2, y + h / 2, -angle)
            x4, y4 = rotatePoint(x, y, x - w / 2, y + h / 2, -angle)
            temp = str('%.2f' % x1) + ',' + str('%.2f' % y1) + ',' + str('%.2f' % x2) + ',' + str('%.2f' % y2) + ',' + \
                   str('%.2f' % x3) + ',' + str('%.2f' % y3) + ',' + \
                   str('%.2f' % x4) + ',' + str('%.2f' % y4) + ',' + name + '\n'
            f_w.write(temp)
        f_w.close()

# 转换成四点坐标
def rotatePoint(xc, yc, xp, yp, theta):
    xoff = xp - xc;
    yoff = yp - yc;
    cosTheta = math.cos(theta)
    sinTheta = math.sin(theta)
    pResx = cosTheta * xoff + sinTheta * yoff
    pResy = - sinTheta * xoff + cosTheta * yoff
    return xc + pResx, yc + pResy


if __name__ == '__main__':
    image_dir = "./origin_png"  # img目录
    indir = "./xml"  # xml目录
    outdir = "./txt"
    xml_to_txt(indir, outdir)
# det_r18_vd_east.yml

Global:
  algorithm: EAST   # EAST算法是目前比较优秀的文本检测算法
  use_gpu: true
  epoch_num: 1000
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/det_east_real/
  save_epoch_step: 200
  eval_batch_step: [3000, 2000]
  train_batch_size_per_card: 8
  test_batch_size_per_card: 1
  image_shape: [3, 512, 512]
  reader_yml: ./configs/det/det_east_icdar15_reader.yml
  pretrain_weights: ./pretrain_models/detect_pretrain_models/ResNet18_vd_pretrained/
  save_res_path: ./output/det_east_real/predicts_east.txt
  checkpoints:
  save_inference_dir:
  infer_img:

Architecture:
  function: ppocr.modeling.architectures.det_model,DetModel

Backbone:
  function: ppocr.modeling.backbones.det_resnet_vd,ResNet
  layers: 18

Head:
  function: ppocr.modeling.heads.det_east_head,EASTHead
  model_name: large
  
Loss:
  function: ppocr.modeling.losses.det_east_loss,EASTLoss

Optimizer:
  function: ppocr.optimizer,AdamDecay
  base_lr: 0.001
  beta1: 0.9
  beta2: 0.999

PostProcess:
  function: ppocr.postprocess.east_postprocess,EASTPostPocess
  score_thresh: 0.8       # 记住这几个参数,后面有用
  cover_thresh: 0.1
  nms_thresh: 0.2

你可能感兴趣的:(神经网络学习,python)