Baidu PaddlePaddle OCR Text Recognition (98% Accuracy)

The full source code is below. If you need a walkthrough, a complete explanation of the approach, or the configuration files, you can find my contact information in my other articles.

import argparse
import base64
import hashlib
import json
import logging as logger
import math
import os
import sys
import time
from threading import Thread

import cv2
import numpy as np
import paddle.fluid as fluid
import requests
from flask import request, Flask, Request
from paddle.fluid.core_avx import AnalysisConfig, create_paddle_predictor

__dir__ = os.path.dirname(os.path.abspath(__file__))

from werkzeug.serving import run_simple

sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))


class CharacterOps(object):
    """
    Convert between text-label and text-index
    """

    def __init__(self, config):
        self.character_type = config['character_type']
        self.loss_type = config['loss_type']
        self.max_text_len = config['max_text_length']
        # use the default dictionary(36 char)
        if self.character_type == "en":
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        # use the custom dictionary
        elif self.character_type == "ch":
            character_dict_path = config['character_dict_path']
            add_space = False
            if 'use_space_char' in config:
                add_space = config['use_space_char']
            self.character_str = ""
            with open(character_dict_path, "rb") as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode('utf-8').strip("\n").strip("\r\n")
                    self.character_str += line
            if add_space:
                self.character_str += " "
            dict_character = list(self.character_str)
        else:
            self.character_str = None
        assert self.character_str is not None, \
            "Unsupported character type: {}".format(self.character_type)
        self.beg_str = "sos"
        self.end_str = "eos"
        # add start and end str for attention
        # create char dict
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character

    def decode(self, text_index, is_remove_duplicate=False):
        """
        convert text-index into text-label.
        Args:
            text_index: text index for each image
            is_remove_duplicate: whether to remove duplicate characters;
                                 defaults to False
        Return:
            text: text label
        """
        char_list = []
        char_num = self.get_char_num()

        ignored_tokens = [char_num]

        for idx in range(len(text_index)):
            if text_index[idx] in ignored_tokens:
                continue
            if is_remove_duplicate:
                if idx > 0 and text_index[idx - 1] == text_index[idx]:
                    continue
            char_list.append(self.character[int(text_index[idx])])
        text = ''.join(char_list)
        return text

    def get_char_num(self):
        """
        Get character num
        """
        return len(self.character)

    def get_beg_end_flag_idx(self, beg_or_end):
        if self.loss_type == "attention":
            if beg_or_end == "beg":
                idx = np.array(self.dict[self.beg_str])
            elif beg_or_end == "end":
                idx = np.array(self.dict[self.end_str])
            else:
                assert False, "Unsupport type %s in get_beg_end_flag_idx" \
                              % beg_or_end
            return idx
        else:
            err = "error in get_beg_end_flag_idx when using the loss %s" \
                  % (self.loss_type)
            assert False, err


def create_predictor(args):
    # the exported inference files (__model__ / params) are looked up in rec_model_dir
    model_file_path = os.path.join(args.rec_model_dir, "__model__")
    params_file_path = os.path.join(args.rec_model_dir, "params")
    if not os.path.exists(model_file_path):
        logger.error("model file not found at {}".format(model_file_path))
        sys.exit(0)
    if not os.path.exists(params_file_path):
        logger.error("params file not found at {}".format(params_file_path))
        sys.exit(0)

    config = AnalysisConfig(model_file_path, params_file_path)

    config.disable_gpu()
    config.set_cpu_math_library_num_threads(6)
    if args.enable_mkldnn:
        config.set_mkldnn_cache_capacity(10)
        config.enable_mkldnn()

    config.disable_glog_info()

    if args.use_zero_copy_run:
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
        config.switch_use_feed_fetch_ops(False)
    else:
        config.switch_use_feed_fetch_ops(True)

    predictor = create_paddle_predictor(config)
    input_names = predictor.get_input_names()
    for name in input_names:
        input_tensor = predictor.get_input_tensor(name)
    output_names = predictor.get_output_names()
    output_tensors = []
    for output_name in output_names:
        output_tensor = predictor.get_output_tensor(output_name)
        output_tensors.append(output_tensor)
    return predictor, input_tensor, output_tensors


def initial_logger():
    FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
    logger.basicConfig(level=logger.INFO, format=FORMAT)
    logger1 = logger.getLogger(__name__)
    return logger1


class TextRecognizer(object):
    def __init__(self, args):
        if args.use_pdserving is False:
            self.predictor, self.input_tensor, self.output_tensors = \
                create_predictor(args)
            self.use_zero_copy_run = args.use_zero_copy_run
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.text_len = args.max_text_length
        char_ops_params = {
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char,
            "max_text_length": args.max_text_length,
            "loss_type": "ctc",
        }

        self.loss_type = 'ctc'
        self.char_ops = CharacterOps(char_ops_params)

    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        wh_ratio = max(max_wh_ratio, imgW * 1.0 / imgH)
        if self.character_type == "ch":
            imgW = int((32 * wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))

        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        predict_time = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            max_wh_ratio = 0
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                max_wh_ratio)
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)

            norm_img_batch = np.concatenate(norm_img_batch, axis=0)
            norm_img_batch = norm_img_batch.copy()

            starttime = time.time()
            if self.use_zero_copy_run:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.zero_copy_run()
            else:
                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
                self.predictor.run([norm_img_batch])

            rec_idx_batch = self.output_tensors[0].copy_to_cpu()
            rec_idx_lod = self.output_tensors[0].lod()[0]
            predict_batch = self.output_tensors[1].copy_to_cpu()
            predict_lod = self.output_tensors[1].lod()[0]
            elapse = time.time() - starttime
            predict_time += elapse
            for rno in range(len(rec_idx_lod) - 1):
                beg = rec_idx_lod[rno]
                end = rec_idx_lod[rno + 1]
                rec_idx_tmp = rec_idx_batch[beg:end, 0]
                preds_text = self.char_ops.decode(rec_idx_tmp)
                beg = predict_lod[rno]
                end = predict_lod[rno + 1]
                probs = predict_batch[beg:end, :]
                ind = np.argmax(probs, axis=1)
                blank = probs.shape[1]
                valid_ind = np.where(ind != (blank - 1))[0]
                if len(valid_ind) == 0:
                    continue
                score = np.mean(probs[valid_ind, ind[valid_ind]])
                rec_res[indices[beg_img_no + rno]] = [preds_text, score]

        return rec_res, predict_time
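
# Standalone usage sketch (Flask aside; assumes a local image file test.jpg):
#   args = parse_args()
#   img = cv2.imread('test.jpg')
#   rec_res, predict_time = TextRecognizer(args)([img])
#   print(rec_res)  # [[recognized_text, confidence_score]]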


def parse_args():
    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=False)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str, default='')
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=120)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)

    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    return parser.parse_args()
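
# Possible launch command (a sketch; the file name ocr_server.py is assumed, with
# __model__ and params in rec_model_dir and ppocr_keys_v1.txt in the working directory):
#   python ocr_server.py --rec_model_dir . --rec_char_dict_path ppocr_keys_v1.txt
# parse_args() reads sys.argv, so these flags are picked up by main() on each /ocr request.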


def base64_to_image(base64_code):
    """将base64的数据转换成rgb格式的图像矩阵"""
    img_data = base64.b64decode(base64_code)
    img_array = np.frombuffer(img_data, np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    return img


def main(args, image_str):
    img_list = []
    try:
        img = base64_to_image(image_str)
        img_list.append(img)
    except Exception as e:
        print(e)
        return 'img_str'
    try:
        text_recognizer = TextRecognizer(args)
        rec_res, predict_time = text_recognizer(img_list)
    except Exception as e:
        print(e)
        return 'text_recognizer'
    if rec_res:
        print("Predict:%s" % (rec_res[0]))
        print("Total predict time for %d images:%.3f" %
              (len(img_list), predict_time))
        return rec_res[0][0]
    else:
        return 'text_recognizer'


app = Flask('ocr')


@app.route('/ocr', methods=['POST'])  # OCR recognition endpoint
def ocr():
    try:
        json_str = request.json
    except Exception as e:
        print(e)
        return json.dumps({
            'status': 0,
            'msg': 'json wrong!'
        })
    if json_str:
        keys = json_str.keys()
        if 'code' in keys:
            if 'image' in keys:
                image_str = json_str['image']
                if image_str:
                    # md5_str is the request-signing secret; it is not defined in the
                    # original snippet, so a placeholder is used here (replace with your own key)
                    md5_str = 'REPLACE_WITH_YOUR_SECRET_KEY'
                    code = hashlib.new('md5', md5_str.encode(encoding='UTF-8')).hexdigest()
                    if code == json_str['code']:
                        rec_res = main(parse_args(), image_str)
                        if rec_res == 'img_str':
                            print('The base64 string cannot be parsed')
                            return json.dumps({
                                'status': -1,
                                'msg': 'The base64 string cannot be parsed'
                            })
                        elif rec_res == 'text_recognizer':
                            print('Image recognition failed')
                            return json.dumps({
                                'status': -1,
                                'msg': 'The picture is not recognized'
                            })
                        else:
                            return json.dumps({
                                'status': 1,
                                'data': rec_res
                            })
                    else:
                        return json.dumps({
                            'status': -1,
                            'msg': 'Code verification failed'
                        })
                else:
                    return json.dumps({
                        'status': -1,
                        'msg': 'The parameter is empty or the parameter is not standard'
                    })
            else:
                return json.dumps({
                    'status': -1,
                    'msg': 'image is null'
                })
        else:
            return json.dumps({
                'status': 0,
                'msg': 'Missing parameter'
            })
    else:
        return json.dumps({
            'status': 0,
            'msg': 'json is null'
        })


def application():
    # heartbeat loop: report this OCR service to the TaskSave endpoint every 10 seconds
    while True:
        dd = requests.get(url_bert)
        print(dd.text)
        time.sleep(10)


def start_app():
    app.run('192.168.0.128', port=52013)  # start the Flask service


if __name__ == '__main__':
    print('start app server!')
    url_bert = 'http://192.168.0.128:8080/HT/api/TaskSave?task=PROCPPS.OCR文字识别服务&notice=1&key=202101041434'
    Thread(target=start_app).start()
    Thread(target=application).start()
    # app.run(host='192.168.0.128', port=52013)  # run the service in the main thread instead
    print('end app server!')
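
For reference, here is a minimal client sketch for calling the service as a separate script. It assumes the server is reachable at http://192.168.0.128:52013/ocr, that md5_str on the server is set to the same placeholder secret used below, and that test.jpg is any local image; the client sends the MD5 digest of the secret as code plus the base64-encoded image, matching the JSON contract checked in ocr().

import base64
import hashlib
import json

import requests

SECRET = 'REPLACE_WITH_YOUR_SECRET_KEY'  # placeholder; must match md5_str on the server

# read and base64-encode a local image (hypothetical file name)
with open('test.jpg', 'rb') as f:
    image_b64 = base64.b64encode(f.read()).decode('utf-8')

payload = {
    'code': hashlib.md5(SECRET.encode('utf-8')).hexdigest(),
    'image': image_b64,
}
resp = requests.post('http://192.168.0.128:52013/ocr', json=payload)
print(json.loads(resp.text))  # e.g. {'status': 1, 'data': '...'} on success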
