Playing with PaddleOCR

Wrapping a PaddleOCR demo

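Preface

I have been interested in OCR for a while. From a quick look around, few open-source OCR projects seem to release their training code. Baidu is comparatively generous here and gives developers step-by-step tutorials. Still, the project does not seem to ship a simple script that is easy to modify. So, just for fun, I cloned PaddleOCR and wrote a small PaddleOCR demo. This post is a record of it.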

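1. Configuration file

Baidu's OCR uses three models: detection, recognition, and classification. The original repository documents them in some detail. Here I simply write their settings into a YAML configuration file.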

image_dir: ./imgs/11.jpg
det: true
rec: true
use_angle_cls: false

drop_score: 0.5


# DB params
det_algorithm:
 name: DB
 use_gpu: true
 model_dir: 2.0_model/det
 pre_process:
   DetResizeForTest: 
     limit_side_len: 960
     limit_type: max
   NormalizeImage:
     std: [0.229, 0.224, 0.225]
     mean: [0.485, 0.456, 0.406]
     scale: 1./255.
     order: hwc
   ToCHWImage:
   KeepKeys:
     keep_keys: ['image', 'shape']

 postprocess:
   name: DBPostProcess
   thresh: 0.4
   box_thresh: 0.5
   max_candidates: 1000
   unclip_ratio: 2.0
   use_dilation: True
lang:
 name: ch
 dict_path: ./ppocr/utils/ppocr_keys_v1.txt

rec_algorithm:
 name: CRNN
 use_gpu: true
 model_dir: 2.0_model/rec/ch
 rec_image_shape: "3, 32, 320"
 rec_char_type: ch
 rec_batch_num: 30
 max_text_length: 25
 postprocess:
   name: CTCLabelDecode
   character_dict_path: ./dict/ppocr_keys_v1.txt
   use_space_char: True

text_classifier:
 use_gpu: true
 model_dir: 2.0_model/cls
 cls_image_shape: "3, 48, 192"
 label_list: ['0', '180']
 cls_batch_num: 30
 cls_thresh: 0.9
 postprocess:
   name: ClsPostProcess
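One detail of the config above is easy to miss: `scale: 1./255.` and the two `*_image_shape` entries are not valid YAML numbers, so a YAML loader hands them back as plain strings. The post does not show how the demo converts them, so the following is only a minimal sketch of one way to do it. The helper names `parse_scale` and `parse_shape` are my own and are not part of PaddleOCR.

```python
# Values as a YAML loader would return them for the config above.
# They are strings, not numbers, so they need converting before use.
raw = {
    "scale": "1./255.",
    "rec_image_shape": "3, 32, 320",
    "cls_image_shape": "3, 48, 192",
}


def parse_scale(value):
    """Turn a scale like '1./255.' (or a plain number) into a float without eval()."""
    if isinstance(value, (int, float)):
        return float(value)
    numerator, _, denominator = value.partition("/")
    if denominator:
        return float(numerator) / float(denominator)
    return float(numerator)


def parse_shape(value):
    """Turn a 'C, H, W' string into a tuple of ints."""
    return tuple(int(part) for part in value.split(","))


scale = parse_scale(raw["scale"])
rec_shape = parse_shape(raw["rec_image_shape"])
cls_shape = parse_shape(raw["cls_image_shape"])
print(scale, rec_shape, cls_shape)
```

Splitting on `/` keeps the config format unchanged while avoiding `eval()` on text read from a file.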

2. Extracting and wrapping the classes

All class implementations in the demo follow the original repository.

import cv2
import numpy as np

# TextSystem is the detection + recognition pipeline class taken from the
# original PaddleOCR repository; import it from wherever it lives in your copy.


class PaddleOCR(TextSystem):
    def __init__(self, args):
        """
        paddleocr package
        args:
            args: dict loaded from the YAML configuration file
        """

        self.use_angle_cls = args['use_angle_cls']
        super().__init__(args)

    def ocr(self, img, det=True, rec=True, cls=False):
        """
        ocr with paddleocr
        args:
            img: img for ocr, supports an ndarray or a list of ndarray
            det: use text detection or not, if false, only rec will be exec. default is True
            rec: use text recognition or not, if false, only det will be exec. default is True
            cls: use the angle classifier or not when det is false. default is False
        """
        # Grayscale input is expanded to 3 channels, which the models expect
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if det and rec:
            # Full pipeline: returns [box, (text, score)] pairs
            dt_boxes, rec_res = self.__call__(img)
            return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
        elif det and not rec:
            # Detection only: returns the boxes
            dt_boxes, elapse = self.text_detector(img)
            if dt_boxes is None:
                return None
            return [box.tolist() for box in dt_boxes]
        else:
            # No detection: img is treated as already-cropped text line(s)
            if not isinstance(img, list):
                img = [img]
            if self.use_angle_cls and cls:
                img, cls_res, elapse = self.text_classifier(img)
                if not rec:
                    return cls_res
            rec_res, elapse = self.text_recognizer(img)
            return rec_res
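The branching in `ocr()` is easier to follow as a table of which stages run for each flag combination. Below is a small sketch that mirrors the branches without loading any model. `stages_for` is a hypothetical helper of mine. In the demo script the constructor flag and the `cls` argument are both read from the same `use_angle_cls` config key, so the sketch uses a single `use_cls` flag. It also assumes that the full pipeline inside `TextSystem` classifies between detection and recognition when the classifier is enabled, which the post does not show.

```python
def stages_for(det, rec, use_cls=False):
    """List the stages ocr() would run for a given flag combination."""
    if det and rec:
        # Full pipeline; assumed to classify between detection and recognition.
        return ["detect"] + (["classify"] if use_cls else []) + ["recognize"]
    if det:
        # Detection only: the caller gets the boxes back.
        return ["detect"]
    # No detection: the input is treated as already-cropped text line(s).
    stages = []
    if use_cls:
        stages.append("classify")
        if not rec:
            return stages  # classification results only
    stages.append("recognize")
    return stages


for det in (True, False):
    for rec in (True, False):
        print(f"det={det!s:5} rec={rec!s:5} -> {stages_for(det, rec, use_cls=True)}")
```

One thing the table makes visible: with `det=False`, `rec=False`, and the classifier off, the method as written still falls through to recognition.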

3. Adding a visualization demo

import argparse
import yaml
import cv2
import numpy as np
from PIL import ImageFont, ImageDraw, Image
from core.PaddleOCR import PaddleOCR

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--params", type=str, default='data/params.yaml')
    return parser.parse_args()

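def draw_result(img, result):
    # White canvas of the same size as the input, used to render the recognized text
    img_rec = np.ones_like(img, np.uint8) * 255
    img_pil = Image.fromarray(img_rec)
    draw = ImageDraw.Draw(img_pil)
    # A font with CJK glyphs is needed to render Chinese text
    fontpath = "font/simsun.ttc"
    font = ImageFont.truetype(fontpath, 16)
    for info in result:
        bbox, rec_info = info
        pts = np.array(bbox, np.int32)
        pts = pts.reshape((-1, 1, 2))
        # Draw the detected box on the input image (modifies img in place)
        cv2.polylines(img, [pts], True, (0, 255, 0), 2)
        # Recognized text followed by its confidence score
        txt = f"{rec_info[0]} {rec_info[1]:.2f}"
        # Write the text at the first corner of the box
        x, y = pts[0][0]
        draw.text((int(x), int(y)), txt, font=font, fill=(0, 255, 0))
    bk_img = np.array(img_pil)
    # Boxes on the left, rendered text on the right
    draw_img = np.hstack([img, bk_img])
    return draw_img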

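if __name__ == '__main__':
    args = parse_args()

    with open(args.params, encoding='utf-8') as f:
        data_dict = yaml.safe_load(f)  # config as a nested dict

    ocr_engine = PaddleOCR(data_dict)
    img = cv2.imread(data_dict['image_dir'])
    if img is None:
        # cv2.imread returns None instead of raising when the file cannot be read
        raise FileNotFoundError(data_dict['image_dir'])
    result = ocr_engine.ocr(img,
                            det=data_dict['det'],
                            rec=data_dict['rec'],
                            cls=data_dict['use_angle_cls'])
    # draw_result expects [box, (text, score)] pairs, i.e. det and rec both enabled
    draw_img = draw_result(img, result)
    cv2.imwrite('result.jpg', draw_img)
    cv2.imshow("img", draw_img)
    cv2.waitKey(0)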
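Besides drawing, the same `result` list can be turned into plain text. With `det` and `rec` both enabled, each entry is a `[box, (text, score)]` pair, as `draw_result` relies on. The sketch below keeps entries above a score threshold and sorts them roughly top to bottom, left to right. `result_to_lines` is a hypothetical helper, the sample data is made up for illustration, and it assumes the first point of each box is its top-left corner.

```python
def result_to_lines(result, min_score=0.5):
    """Return recognized strings in rough reading order."""
    if not result:
        return []
    kept = [(box, text) for box, (text, score) in result if score >= min_score]
    # Sort by the first corner of each box: y first, then x.
    kept.sort(key=lambda item: (item[0][0][1], item[0][0][0]))
    return [text for _, text in kept]


# Made-up sample in the same shape as the det + rec output.
sample = [
    [[[10, 50], [90, 50], [90, 70], [10, 70]], ("second line", 0.93)],
    [[[10, 10], [90, 10], [90, 30], [10, 30]], ("first line", 0.98)],
    [[[10, 90], [90, 90], [90, 110], [10, 110]], ("noise", 0.31)],
]
print(result_to_lines(sample))
```

The default `min_score` of 0.5 simply echoes the `drop_score` value in the config. Sorting by a single corner is crude and will misorder lines that are skewed or sit at nearly the same height.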

Final output
[Image 1: output of the demo]
