Paddle OCR文字检测(二)(微调)

实际使用中,官方建议加载官方预训练模型,在自己的数据集中微调。所以本篇写一下文本检测模型的微调方法。

1.背景和意义

PaddleOCR提供的PP-OCR系列模型在通用场景下表现优异,可以解决大部分情况下的检测识别问题。在垂直场景下,如果想获得更好的模型,可以通过fine-tune进一步提升PP-OCR系列检测识别模型的准确率。

本文主要介绍微调文本检测识别模型时的一些注意事项。最后,您可以在自己的场景中通过模型微调得到精度更高的文本检测和识别模型。

本文的核心观点如下:

  1. PP-OCR提供的预训练模型具有更好的泛化能力

  1. 加入少量真实数据(检测:>=500,识别:>=5000)会大大提升垂直场景的检测识别效果

  1. 在微调模型时,加入真实的通用场景数据可以进一步提高模型精度和泛化性能

  1. 在文本检测任务中,增加图像的预测形状可以进一步提高较小文本区域的检测效果

  1. 在微调模型时,需要适当调整超参数(学习率、batch size最重要)以获得更好的微调效果。

2.数据集
  • 数据集:建议至少准备500个文本检测数据集用于模型微调。

  • 数据集标注:单行文本标注格式,建议标注的检测框与实际语义内容一致。比如在火车票场景中,姓氏和名字可能相距很远,但在语义上属于同一个检测域。这里,还需要将整个名字标记为检测框。

3.型号

推荐选择PP-OCRv3模型(配置文件:ch_PP-OCRv3_det_student.yml,预训练模型:ch_PP-OCRv3_det_distill_train.tar),其准确率和泛化性能是目前最好的预训练模型。

更多PP-OCR系列机型请参考PP-OCR系列机型库。

注意:使用上述预训练模型时,需要使用student.pdparams文件夹中的文件作为预训练模型,即只使用student模型。

4.训练超参数

在微调模型时,最重要的超参数是预训练模型路径pretrained_model,learning_rate与batch_size,部分超参数如下:

Global:
  pretrained_model: ./ch_PP-OCRv3_det_distill_train/student.pdparams # pre-training model path
Optimizer:
  lr:
    name: Cosine
    learning_rate: 0.001 # learning_rate
    warmup_epoch: 2
  regularizer:
    name: 'L2'
    factor: 0
Train:
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8  # single gpu batch size
    num_workers: 4

在上面的配置文件中,需要指定pretrained_model字段作为student.pdparams文件路径。

PaddleOCR提供的配置文件是8-gpu训练(相当于总batch size为8*8=64),没有加载预训练模型。因此,在你的场景中,学习率和总的batch size一样,需要线性调整,例如

  • 如果你的场景是单gpu训练,单gpu batch_size=8,那么total batch_size=8,建议调整学习率到约 1e-4。

  • 如果你的场景是单gpu训练,由于内存限制,可以只对单gpu设置batch_size=4,总的batch_size=4。建议将学习率调整为 约 5e-5。

5. 预测超参数

在导出和推断训练好的模型时,可以进一步调整预测的图像比例,以提高小面积文本的检测效果。以下是DBNet推理时的一些超参数,可以适当调整以提高效果。

hyperparameter

type

default

meaning

det_db_thresh

float

0.3

在DB输出的概率图中,得分大于阈值的像素将被认为是文本像素

det_db_box_thresh

float

0.6

当检测结果帧内所有像素点的平均得分大于阈值时,该结果将被认为是文本区域

det_db_unclip_ratio

float

1.5

Vatti clipping 的扩展系数,使用该方法扩展文本区域

max_batch_size

int

10

batch size

use_dilation

bool

False

是否扩展分割结果以获得更好的检测结果

det_db_score_mode

str

"fast"

DB的检测结果分数计算方法支持fastslowfast计算多边形外接矩形边框内所有像素点slow的平均分,和原多边形内所有像素点计算平均分。计算速度相对较慢,但更准确。

有关推理方法的更多信息,请参阅Paddle Inference doc。

详情请参考:

模型微调教程https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_en/finetune_en.md

你可能感兴趣的:(OCR,paddle,深度学习,人工智能)