对paddleOCR中的字符识别模型转ONNX

本文介绍如何将 PaddleOCR 中的字符识别模型转换成 ONNX 格式。

转换代码:



import os
import sys
import yaml
import numpy as np
import cv2
import argparse
import paddle
from paddle import nn

from argparse import ArgumentParser, RawDescriptionHelpFormatter
import paddle.distributed as dist
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.modeling.architectures import build_model


class AttrDict(dict):
    """Flat dict whose entries can also be read as attributes (not recursive)."""

    def __init__(self, **kwargs):
        # Start from an empty dict, then copy the keyword arguments in.
        dict.__init__(self)
        dict.update(self, kwargs)

    def __getattr__(self, key):
        # Reached only when regular attribute lookup fails, so fall back to
        # the dict entries and report a missing key as a missing attribute.
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]

# Module-wide configuration store; merge_config() writes into it in place.
global_config = AttrDict()
# Baseline values that load_config() merges before the YAML file contents.
default_config = {'Global': {'debug': False, }}

class ArgsParser(ArgumentParser):
    """Command-line parser for the export script.

    Options:
        -c / --config: path of the YAML configuration file.
        -o / --opt: one or more ``key=value`` overrides. After parsing they
            are exposed on ``args.opt`` as a dict (empty when not given).
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        # self.add_argument("-c", "--config", default='./configs/ch_PP-OCRv2_rec_idcard.yml',
        #                   help="configuration file to use")

        self.add_argument("-c", "--config", default='./configs/ch_PP-OCRv2_rec.yml',
                          help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` and turn the ``--opt`` list into a dict.

        Args:
            argv (list[str] | None): arguments to parse; ``None`` lets
                argparse read them from ``sys.argv``.
        Returns:
            argparse.Namespace with ``config`` (str) and ``opt`` (dict).
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``["k=v", ...]`` into ``{k: parsed_v}``.

        Args:
            opts (list[str] | None): raw ``key=value`` strings.
        Returns:
            dict mapping each key to its YAML-parsed value; empty when
            ``opts`` is None or empty.
        Raises:
            ValueError: an entry contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                raise ValueError(
                    "invalid --opt entry '{}': expected key=value".format(s))
            # Split on the first '=' only, so values that themselves contain
            # '=' (e.g. paths or URLs) are kept intact instead of raising.
            k, v = s.split('=', 1)
            # NOTE(security): yaml.Loader can construct arbitrary Python
            # objects; this is acceptable only because the value comes from
            # the operator's own command line. Consider yaml.safe_load.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config

def merge_config(config):
    """
    Merge config into global config.

    Keys without a dot are written to the top level of ``global_config``:
    a dict value whose key already exists updates the existing entry in
    place, anything else replaces the entry. Dotted keys such as
    ``"Global.use_gpu"`` walk into the existing nested entries and assign
    the last component.

    Args:
        config (dict): Config to be merged.
    Returns: global config
    Raises:
        AssertionError: the first component of a dotted key is not a
            top-level key of ``global_config``.
        KeyError: an intermediate component of a dotted key does not exist.
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            # Descend through every component except the last one, then
            # assign the value under the last component.
            for sub_key in sub_keys[1:-1]:
                cur = cur[sub_key]
            cur[sub_keys[-1]] = value
    # Return the merged store, as the docstring has always promised.
    return global_config

def load_config(file_path):
    """
    Load config from yml/yaml file.

    The defaults in ``default_config`` are merged first, then the contents
    of the file.

    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    Raises:
        AssertionError: the file extension is neither ``.yml`` nor ``.yaml``.
    """
    merge_config(default_config)
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Use a context manager so the file handle is closed even if parsing fails.
    with open(file_path, 'rb') as config_file:
        # NOTE(security): yaml.Loader can construct arbitrary Python objects;
        # only load configuration files from trusted sources.
        loaded = yaml.load(config_file, Loader=yaml.Loader)
    merge_config(loaded)
    return global_config

def check_device(use_gpu, use_xpu=False):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.

    The same check is applied to ``use_xpu``. The check is best-effort: if
    querying paddle's build capabilities itself raises, the reason is
    printed and the function returns normally.

    Args:
        use_gpu (bool): whether the config asks for GPU execution.
        use_xpu (bool): whether the config asks for XPU execution.
    Raises:
        SystemExit: exit status 1 when the requested device is not supported
            by the installed paddle build.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    try:
        if use_gpu and use_xpu:
            print("use_xpu and use_gpu can not both be ture.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
    except Exception as e:
        # SystemExit is not an Exception subclass, so the exits above still
        # propagate. Anything else means the capability query failed; report
        # it instead of hiding it, but keep the check non-fatal.
        print("check_device: device capability check skipped ({}: {})".format(
            type(e).__name__, e))

def getArgs(is_train=False):
    """Parse the command line, build the merged config and select the device.

    Args:
        is_train (bool): accepted but not read by this function.
    Returns:
        tuple: ``(config, device)`` where ``config`` is the merged global
        config and ``device`` is the value returned by ``paddle.set_device``.
    Raises:
        AssertionError: the configured algorithm is not in the supported list.
    """
    cli_args = ArgsParser().parse_args()
    config = load_config(cli_args.config)
    # Command-line overrides are applied on top of the file contents.
    merge_config(cli_args.opt)

    # The device check below compares this flag against the paddle build.
    gpu_requested = config['Global']['use_gpu']
    xpu_requested = False

    supported_algorithms = [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner'
    ]
    algorithm = config['Architecture']['algorithm']
    assert algorithm in supported_algorithms

    if gpu_requested:
        device_name = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
    else:
        device_name = 'cpu'
    check_device(gpu_requested, xpu_requested)

    device = paddle.set_device(device_name)

    config['Global']['distributed'] = dist.get_world_size() != 1

    return config, device


class CRNN(nn.Layer):
    """Wrapper layer that runs the recognition model and reduces its output.

    ``forward`` returns the per-position argmax index and maximum value of
    the wrapped model's output, so these reductions become part of whatever
    is exported from this layer.
    """

    def __init__(self, config, device):
        """Build the post-processor and the model, then load the weights.

        Args:
            config: merged configuration; the 'PostProcess', 'Global' and
                'Architecture' entries are read, and the head's output
                channel settings inside 'Architecture' are written in place.
            device: accepted but not used inside this class.
        """
        super(CRNN, self).__init__()
        # Preprocessing parameters (per-channel mean/std). They are stored
        # here, but the normalization in forward() is currently commented out.
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        self.mean = paddle.to_tensor(mean).reshape([1, 3, 1, 1])
        self.std = paddle.to_tensor(std).reshape([1, 3, 1, 1])

        self.config = config
        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     config['Global'])
        # build model
        # The head's output channel count is derived from the size of the
        # post-processor's character list, when it exposes one.
        if hasattr(self.post_process_class, 'character'):
            char_num = len(getattr(self.post_process_class, 'character'))
            if self.config['Architecture']["algorithm"] in ["Distillation",
                                                            ]:  # distillation model
                # NOTE(review): char_num is reassigned inside this loop, so
                # with several MultiHead models under
                # 'DistillationSARLabelDecode' it is reduced by 2 once per
                # model - confirm this is intended.
                for key in self.config['Architecture']["Models"]:
                    if self.config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                        out_channels_list = {}
                        if self.config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                            char_num = char_num - 2
                        out_channels_list['CTCLabelDecode'] = char_num
                        out_channels_list['SARLabelDecode'] = char_num + 2
                        self.config['Architecture']['Models'][key]['Head'][
                            'out_channels_list'] = out_channels_list
                    else:
                        self.config['Architecture']["Models"][key]["Head"][
                            'out_channels'] = char_num
            elif self.config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
                out_channels_list = {}
                if self.config['PostProcess']['name'] == 'SARLabelDecode':
                    char_num = char_num - 2
                out_channels_list['CTCLabelDecode'] = char_num
                out_channels_list['SARLabelDecode'] = char_num + 2
                self.config['Architecture']['Head'][
                    'out_channels_list'] = out_channels_list
            else:  # base rec model
                self.config['Architecture']["Head"]['out_channels'] = char_num

        # Build the model, initialise it via init_model (presumably loads the
        # checkpoint named in the config - TODO confirm) and switch to eval mode.
        self.model = build_model(config['Architecture'])
        # load_model(config, self.model)
        init_model(self.config, self.model)
        self.model.eval()

    def forward(self, x):
        """Run the model and return ``(preds_idx, preds_prob)`` as float32.

        Both outputs are reductions of the model output along axis 2: the
        argmax index and the maximum value.
        NOTE(review): assumes the model output is a single tensor shaped
        (batch, sequence, classes) - TODO confirm for the configured head.
        """
        # Input layout conversion and normalization are disabled here, so
        # the caller must supply an already preprocessed tensor.
        # x = paddle.transpose(x, [0,3,1,2])
        # x = x / 255.0
        # x = (x - self.mean) / self.std

        model_out = self.model(x)

        # return model_out
        preds_idx = model_out.argmax(axis=2, name='class').astype('float32')
        # preds_idx = model_out.argmax(axis=2, name='class')
        preds_prob = model_out.max(axis=2, name='score').astype('float32')
        return preds_idx, preds_prob

# Set to False to only run the forward pass without exporting.
EXPORT_ONNX = True
# When True, export with a variable batch dimension instead of the fixed
# shape taken from the sample input.
DYNAMIC = False

if __name__ == '__main__':
    config, device = getArgs()
    model_crnn = CRNN(config, device=device)

    # Build a sample input: one image resized to 320x32, scaled to [0, 1]
    # and laid out as (1, 3, 32, 320).
    image_path = "1.jpg"
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError("cannot read image: {}".format(image_path))
    img = cv2.resize(img, (320, 32))
    print('input data:', img.shape)
    img = img.astype(np.float32)
    img = img.transpose((2, 0, 1)) / 255
    input_data = img[np.newaxis, :]
    print('input data:', input_data.shape)
    x = paddle.to_tensor(input_data)
    print('input data:', x.shape)

    output_idx, output_prob = model_crnn(x)
    print('output_idx: ', output_idx)
    print('output_prob: ', output_prob)

    input_spec = paddle.static.InputSpec.from_tensor(x,  name='input')
    onnx_save_path = "./export_onnx"
    if EXPORT_ONNX:
        os.makedirs(onnx_save_path, exist_ok=True)
        onnx_model_name = onnx_save_path + "/char_recognize_20230526_v1"
        if DYNAMIC:
            # The model is fed channel-first data (the transpose in
            # CRNN.forward is disabled), so the dynamic spec must be
            # (batch, 3, 32, 320) to match the sample input above.
            input_spec = paddle.static.InputSpec(
                shape=[None, 3, 32, 320], dtype='float32',  name='input')

        # Export the ONNX model.
        paddle.onnx.export(model_crnn, onnx_model_name, input_spec=[input_spec], opset_version=11,
                           enable_onnx_checker=True, output_spec=[output_idx, output_prob])

转换完成后,可以使用 Netron 工具将转换后的网络结构绘制出来。

绘制出的网络结构的起始和末尾部分如下:

对paddleOCR中的字符识别模型转ONNX_第1张图片

对paddleOCR中的字符识别模型转ONNX_第2张图片

测试ONNX的代码:

'''
测试转出的onnx模型
'''
import cv2
import numpy as np

import torch
import onnxruntime as rt
import math
import os

class TestOnnx:
    """Run the exported ONNX recognition model on images and decode the text.

    The character table is ``["blank"] + <lines of the dictionary file>``,
    optionally followed by a space, so index 0 is the blank token that
    ``decode`` drops.
    """

    def __init__(self, onnx_file, character_dict_path, use_space_char=True):
        """Create the inference session and load the character table.

        Args:
            onnx_file (str): path of the exported ``.onnx`` model.
            character_dict_path (str): UTF-8 text file, one character per line.
            use_space_char (bool): append ``" "`` as the last character.
        """
        self.sess = rt.InferenceSession(onnx_file)
        # Names of the graph's input and output nodes.
        self.input_names = [node.name for node in self.sess.get_inputs()]
        self.output_names = [node.name for node in self.sess.get_outputs()]

        self.character = ["blank"]
        with open(character_dict_path, "rb") as fin:
            for line in fin:
                # Drop only line terminators; other whitespace may be a
                # legitimate dictionary entry.
                self.character.append(line.decode('utf-8').strip("\r\n"))
        if use_space_char:
            self.character.append(" ")

    def resize_norm_img(self, img, image_shape=(3, 32, 320)):
        """Resize to the target height, normalize to [-1, 1] and right-pad.

        Args:
            img (np.ndarray): image as returned by ``cv2.imread`` (H, W[, C]).
            image_shape (sequence): target ``(channels, height, width)``.
        Returns:
            np.ndarray: float32 array of shape ``image_shape``; columns to
            the right of the resized image are zero.
        """
        imgC, imgH, imgW = image_shape
        h, w = img.shape[:2]
        ratio = w / float(h)
        # Keep the aspect ratio, but never exceed the target width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
        if imgC == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def process(self, input_names, image):
        """Build the feed dict: every input name maps to the same ``image``."""
        return {input_name: image for input_name in input_names}

    def get_ignored_tokens(self):
        """Return the indices dropped while decoding (the blank token)."""
        return [0]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label.

        Args:
            text_index: per-sample sequences of character indices.
            text_prob: optional per-sample scores aligned with ``text_index``.
            is_remove_duplicate (bool): drop an index equal to its
                predecessor before the ignored tokens are removed.
        Returns:
            list[tuple[str, float]]: one ``(text, mean_score)`` per sample.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx in range(len(text_index)):
            indices = text_index[batch_idx]
            selection = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = indices[1:] != indices[:-1]
            for ignored_token in ignored_tokens:
                selection &= indices != ignored_token

            char_list = [
                self.character[int(text_id)].replace('\n', '')
                for text_id in indices[selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                # Nothing survived the filtering: report a score of 0.
                conf_list = [0]

            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))

        return result_list

    def test(self, image_path):
        """Recognize one image, print the text and return the decoded result.

        Args:
            image_path (str): path of the image to recognize.
        Returns:
            list[tuple[str, float]] from ``decode``, or ``None`` when the
            file cannot be read as an image.
        """
        img_onnx = cv2.imread(image_path)
        if img_onnx is None:
            # cv2.imread returns None for missing or non-image files.
            print(image_path, "skipped: cannot be read as an image")
            return None
        img_onnx = self.resize_norm_img(img_onnx)
        # Add the batch dimension; onnxruntime consumes numpy arrays
        # directly, so no tensor round-trip is needed.
        onnx_indata = img_onnx[np.newaxis, :, :, :].astype(np.float32)
        print('image shape: ', onnx_indata.shape)
        feed_dict = self.process(self.input_names, onnx_indata)

        output_onnx = self.sess.run(self.output_names, feed_dict)

        preds_idx = output_onnx[0]
        preds_prob = output_onnx[1]
        post_result = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)

        # decode() always returns a list of (text, score) tuples.
        print(image_path, post_result[0][0])
        return post_result




if __name__=='__main__':
    image_dir = "./sample/img"
    onnx_file = './export_onnx/char_recognize_20230526_v1.onnx'
    character_dict_path = './all_label_num_20230517.txt'

    testobj = TestOnnx(onnx_file, character_dict_path)

    # Process entries in a stable order and skip anything that is not a
    # regular file (e.g. sub-directories), which cannot be read as an image.
    for file_name in sorted(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, file_name)
        if not os.path.isfile(image_path):
            continue
        testobj.test(image_path)




模型转换结束。 

你可能感兴趣的:(ocr,onnx,python,开发语言)