mmocr: Testing Text Detection and Recognition Models

[Figure 1: sample input image]

[Figure 2: detection and recognition result]

Note: the left image is the original; the right image shows the localization and recognition results.

The script below works with multiple detection and recognition models.

import os

import cv2
import torch

from mmocr.apis.inference import model_inference
from mmocr.apis import init_detector
from mmocr.utils import revert_sync_batchnorm
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.core.visualize import det_recog_show_result

textdet_models = {
    'DB_r18': {
        'config': 'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
        'ckpt': 'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
    },
    'DB_r50': {
        'config': 'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
        'ckpt': 'dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth'
    },
    'DRRG': {
        'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
        'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500_20211022-fb30b001.pth'
    },
    'FCE_IC15': {
        'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
        'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015_20211022-daefb6ed.pth'
    },
    'FCE_CTW_DCNv2': {
        'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
        'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500_20211022-e326d7ec.pth'
    },
    'MaskRCNN_CTW': {
        'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
        'ckpt': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
    },
    'MaskRCNN_IC15': {
        'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
        'ckpt': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
    },
    'MaskRCNN_IC17': {
        'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
        'ckpt': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
    },
    'PANet_CTW': {
        'config': 'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
        'ckpt': 'panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
    },
    'PANet_IC15': {
        'config': 'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
        'ckpt': 'panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
    },
    'PS_CTW': {
        'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
        'ckpt': 'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
    },
    'PS_IC15': {
        'config': 'psenet/psenet_r50_fpnf_600e_icdar2015.py',
        'ckpt': 'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
    },
    'TextSnake': {
        'config': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
        'ckpt': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
    },
    'Tesseract': {}
}

textrecog_models = {
    'CRNN': {
        'config': 'crnn/crnn_academic_dataset.py',
        'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
    },
    'SAR': {
        'config': 'sar/sar_r31_parallel_decoder_academic.py',
        'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
    },
    'SAR_CN': {
        'config': 'sar/sar_r31_parallel_decoder_chinese.py',
        'ckpt': 'sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth'
    },
    'NRTR_1/16-1/8': {
        'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
        'ckpt': 'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'
    },
    'NRTR_1/8-1/4': {
        'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
        'ckpt': 'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth'
    },
    'RobustScanner': {
        'config': 'robust_scanner/robustscanner_r31_academic.py',
        'ckpt': 'robustscanner/robustscanner_r31_academic-5f05874f.pth'
    },
    'SATRN': {
        'config': 'satrn/satrn_academic.py',
        'ckpt': 'satrn/satrn_academic_20211009-cb8b1580.pth'
    },
    'SATRN_sm': {
        'config': 'satrn/satrn_small.py',
        'ckpt': 'satrn/satrn_small_20211009-2cf13355.pth'
    },
    'ABINet': {
        'config': 'abinet/abinet_academic.py',
        'ckpt': 'abinet/abinet_academic-f718abf6.pth'
    },
    'SEG': {
        'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
        'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
    },
    'CRNN_TPS': {
        'config': 'tps/crnn_tps_academic_dataset.py',
        'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
    },
    'Tesseract': {}
}
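
# Example: for the 'DB_r18' entry above, the official pretrained checkpoint URL is
#     'https://download.openmmlab.com/mmocr/textdet/' + textdet_models['DB_r18']['ckpt']
# The __main__ block below assembles checkpoint URLs in exactly this way whenever
# no local checkpoint path is given.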


def single_inference(model, arrays, batch_mode, batch_size=0):
    """Run model_inference on a list of images, one by one or in batches."""
    result = []
    if batch_mode:
        if batch_size == 0:
            # Feed all images to the model as a single batch.
            result = model_inference(model, arrays, batch_mode=True)
        else:
            # Split the images into chunks of `batch_size` and infer chunk by chunk.
            n = batch_size
            arr_chunks = [
                arrays[i:i + n] for i in range(0, len(arrays), n)
            ]
            for chunk in arr_chunks:
                result.extend(model_inference(model, chunk, batch_mode=True))
    else:
        # Run inference image by image.
        for arr in arrays:
            result.append(model_inference(model, arr, batch_mode=False))
    return result
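
# Example: single_inference(recog_model, crops, batch_mode=True, batch_size=4)
# runs recognition over the list `crops` in chunks of 4 images at a time.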


def det_recog_kie_inference(det_model, recog_model, kie_model=None, imgs=None, batch_mode=False, recog_batch_size=0):
    """End-to-end inference: detect text boxes, then recognize the text in each box.

    Note: kie_model is accepted for API compatibility but is not used here.
    """
    end2end_res = []  # end-to-end results for all images

    # 1. Detection (single image or batch)
    det_result = single_inference(det_model, imgs, batch_mode=batch_mode, batch_size=0)
    bboxes_list = [res['boundary_result'] for res in det_result]

    # 2. Recognition (single image or batch)
    for img, bboxes in zip(imgs, bboxes_list):
        img_e2e_res = {}
        img_e2e_res['result'] = []  # every text location and its content found in this image

        box_imgs = []
        for bbox in bboxes:
            box_res = {}
            box_res['box'] = [round(x) for x in bbox[:-1]]  # all but the last value are coordinates
            box_res['box_score'] = float(bbox[-1])  # the last value is the detection confidence
            box = bbox[:8]
            if len(bbox) > 9:  # polygon boxes carry more than 8 values: a contour point set (x, y, x, y, ...)
                min_x = min(bbox[0:-1:2])  # every other value starting at index 0: the x coordinates
                min_y = min(bbox[1:-1:2])  # every other value starting at index 1: the y coordinates
                max_x = max(bbox[0:-1:2])
                max_y = max(bbox[1:-1:2])
                box = [  # axis-aligned bounding rectangle of the polygon
                    min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                ]
            # Crop the detected region out of the image.
            box_img = crop_img(img, box)
            if batch_mode:
                box_imgs.append(box_img)  # collect crops so recognition can run as a batch
            else:
                # Recognize the cropped region immediately.
                recog_result = model_inference(recog_model, box_img)
                text = recog_result['text']  # recognized text
                text_score = recog_result['score']  # confidence
                if isinstance(text_score, list):
                    text_score = sum(text_score) / max(1, len(text))
                box_res['text'] = text
                box_res['text_score'] = text_score
            img_e2e_res['result'].append(box_res)

        if batch_mode:
            # Recognize all collected crops of this image in batches.
            recog_results = single_inference(recog_model, box_imgs, True, recog_batch_size)
            for i, recog_result in enumerate(recog_results):
                text = recog_result['text']
                text_score = recog_result['score']
                if isinstance(text_score, (list, tuple)):
                    text_score = sum(text_score) / max(1, len(text))
                img_e2e_res['result'][i]['text'] = text
                img_e2e_res['result'][i]['text_score'] = text_score
        end2end_res.append(img_e2e_res)
    return end2end_res
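
# The returned structure looks like (values illustrative):
# [{'result': [{'box': [x1, y1, x2, y2, x3, y3, x4, y4],
#               'box_score': 0.98, 'text': 'hello', 'text_score': 0.95},
#              ...]},
#  ...]  # one dict per input image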


# Post-processing for end-to-end OCR: draw and/or save the visualized results.
def det_recog_pp(arrays, result, outputs=None, imshow=True):
    if outputs is None:
        outputs = [None] * len(arrays)
    for arr, output, det_recog_result in zip(arrays, outputs, result):
        if output or imshow:
            res_img = det_recog_show_result(arr, det_recog_result, out_file=output)
            if imshow:
                cv2.namedWindow("inference results", cv2.WINDOW_NORMAL)
                cv2.imshow("inference results", res_img)
                cv2.waitKey()  # press any key to advance to the next image
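
# Example: det_recog_pp([img], results, outputs=['vis.jpg'], imshow=False)
# writes the visualization to 'vis.jpg' (an illustrative path) without opening a window.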


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config_dir = r'D:\code\python\mmocr\configs'

    # Detection model (initialized once, before the image loop).
    td = "DB_r50"
    det_ckpt = r"D:\code\python\mmocr\tools\work_dirs\dbnet_r50_0415\epoch_80.pth"
    det_config = os.path.join(config_dir, "textdet", textdet_models[td]['config'])
    if not det_ckpt:
        # Fall back to the official pretrained checkpoint.
        det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + textdet_models[td]['ckpt']
    detect_model = init_detector(det_config, det_ckpt, device=device)
    detect_model = revert_sync_batchnorm(detect_model)

    # Recognition model.
    tr = "SEG"
    recog_ckpt = None
    recog_config = os.path.join(config_dir, "textrecog", textrecog_models[tr]['config'])
    if not recog_ckpt:
        recog_ckpt = 'https://download.openmmlab.com/mmocr/textrecog/' + textrecog_models[tr]['ckpt']
    recog_model = init_detector(recog_config, recog_ckpt, device=device)
    recog_model = revert_sync_batchnorm(recog_model)

    # Unwrap DataParallel-style wrappers if present.
    if hasattr(detect_model, 'module'):
        detect_model = detect_model.module
    if hasattr(recog_model, 'module'):
        recog_model = recog_model.module

    img_dir = r"E:\data\ocr\text_detection\ctw1500\imgs\training"
    for img_name in os.listdir(img_dir):
        img_path = os.path.join(img_dir, img_name)
        img = cv2.imread(img_path)
        img = cv2.resize(img, (640, 640))

        det_recog_result = det_recog_kie_inference(
            detect_model, recog_model, kie_model=None, imgs=[img], batch_mode=False)
        det_recog_pp([img], det_recog_result, outputs=[None], imshow=True)

Here,

det_ckpt = r"D:\code\python\mmocr\tools\work_dirs\dbnet_r50_0415\epoch_80.pth"

is where I swapped in the DBNet text-detection weights I trained myself.
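
For quick experiments, mmocr 0.x also ships a built-in convenience wrapper that runs the same detection-then-recognition pipeline this script reimplements. Below is a minimal sketch, assuming the mmocr 0.x MMOCR class; the model names are keys from the tables above, and 'demo.jpg' / 'out.jpg' are illustrative paths:

from mmocr.utils.ocr import MMOCR

# Build an end-to-end pipeline from a detector and a recognizer by name.
ocr = MMOCR(det='DB_r50', recog='SEG')
# Detect and recognize text in one image, saving a visualization to disk.
results = ocr.readtext('demo.jpg', output='out.jpg')
print(results)

The wrapper downloads the pretrained checkpoints automatically, so it is a convenient baseline to compare against your own trained weights.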
