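# Source code follows. For a walkthrough, a full explanation of the approach, or the configuration files, see the contact details in my other articles.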
import argparse
import base64
import hashlib
import json
import logging as logger
import math
import os
import sys
import time
from threading import Thread
import cv2
import numpy as np
import paddle.fluid as fluid
import requests
from flask import request, Flask, Request
from paddle.fluid.core_avx import AnalysisConfig, create_paddle_predictor
__dir__ = os.path.dirname(os.path.abspath(__file__))
from werkzeug.serving import run_simple
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
class CharacterOps(object):
    """Convert between text labels and text indices.

    The character table is either the built-in 36-character alphanumeric
    set (``character_type == "en"``) or the concatenation of every line of
    a dictionary file (``character_type == "ch"``).
    """

    def __init__(self, config):
        """Build the character table described by ``config``.

        Args:
            config: mapping holding ``character_type``, ``loss_type`` and
                ``max_text_length``. When ``character_type`` is "ch" it must
                also hold ``character_dict_path`` and may hold
                ``use_space_char``.

        Raises:
            AssertionError: if ``character_type`` is neither "en" nor "ch".
        """
        self.character_type = config['character_type']
        self.loss_type = config['loss_type']
        self.max_text_len = config['max_text_length']
        if self.character_type == "en":
            # Default dictionary (36 characters).
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
        elif self.character_type == "ch":
            # Custom dictionary read from a UTF-8 file.
            character_dict_path = config['character_dict_path']
            add_space = False
            if 'use_space_char' in config:
                add_space = config['use_space_char']
            with open(character_dict_path, "rb") as fin:
                entries = [raw.decode('utf-8').strip("\r\n") for raw in fin]
            self.character_str = "".join(entries)
            if add_space:
                self.character_str += " "
        else:
            self.character_str = None
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the message reports the rejected type.
            raise AssertionError(
                "Nonsupport type of the character: {}".format(
                    self.character_type))
        dict_character = list(self.character_str)
        self.beg_str = "sos"
        self.end_str = "eos"
        # NOTE(review): beg_str/end_str are never inserted into the table
        # below, whose keys are all single characters, so
        # get_beg_end_flag_idx() raises KeyError for the "attention" loss.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def decode(self, text_index, is_remove_duplicate=False):
        """Convert a sequence of text indices into a text label.

        Args:
            text_index: indexable sequence of character indices for one
                image.
            is_remove_duplicate: when True, skip an index equal to the one
                immediately before it. The default is False.

        Returns:
            The decoded string. Indices equal to ``get_char_num()`` are
            ignored.
        """
        ignored_index = self.get_char_num()
        char_list = []
        for pos, index in enumerate(text_index):
            if index == ignored_index:
                continue
            if is_remove_duplicate and pos > 0 \
                    and text_index[pos - 1] == index:
                continue
            char_list.append(self.character[int(index)])
        return ''.join(char_list)

    def get_char_num(self):
        """Return the number of characters in the table."""
        return len(self.character)

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the begin or end marker as a numpy array.

        Args:
            beg_or_end: "beg" for the begin marker, "end" for the end marker.

        Raises:
            AssertionError: if ``loss_type`` is not "attention" or
                ``beg_or_end`` is neither "beg" nor "end".
            KeyError: if the requested marker is not in the table.
        """
        if self.loss_type != "attention":
            err = "error in get_beg_end_flag_idx when using the loss %s" \
                  % (self.loss_type)
            raise AssertionError(err)
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        raise AssertionError(
            "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end)
def create_predictor(args):
    """Create a CPU Paddle predictor from ``__model__`` and ``params``.

    Both files are looked up relative to the current working directory. If
    either is missing, a message is logged at INFO level and the process
    exits with status 0.

    Args:
        args: object exposing ``enable_mkldnn`` and ``use_zero_copy_run``.

    Returns:
        A ``(predictor, input_tensor, output_tensors)`` tuple, where
        ``input_tensor`` is the tensor of the last input name reported by
        the predictor and ``output_tensors`` lists one tensor per output
        name.
    """
    model_file_path = "__model__"
    params_file_path = "params"
    required_files = (
        (model_file_path, "not find __model__ file path {}"),
        (params_file_path, "not find params file path {}"),
    )
    for path, message in required_files:
        if not os.path.exists(path):
            logger.info(message.format(path))
            sys.exit(0)

    config = AnalysisConfig(model_file_path, params_file_path)
    config.disable_gpu()
    config.set_cpu_math_library_num_threads(6)
    if args.enable_mkldnn:
        config.set_mkldnn_cache_capacity(10)
        config.enable_mkldnn()
    config.disable_glog_info()

    # Feed/fetch ops are switched off only when zero-copy run is requested.
    use_feed_fetch_ops = not args.use_zero_copy_run
    if not use_feed_fetch_ops:
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
    config.switch_use_feed_fetch_ops(use_feed_fetch_ops)

    predictor = create_paddle_predictor(config)
    # Every input tensor is fetched, but only the last one is kept.
    for name in predictor.get_input_names():
        input_tensor = predictor.get_input_tensor(name)
    output_tensors = [
        predictor.get_output_tensor(output_name)
        for output_name in predictor.get_output_names()
    ]
    return predictor, input_tensor, output_tensors
def initial_logger():
    """Configure basic logging at INFO level and return this module's logger.

    Records are formatted as ``<asctime>-<levelname>: <message>``.
    """
    logger.basicConfig(
        format='%(asctime)s-%(levelname)s: %(message)s',
        level=logger.INFO,
    )
    return logger.getLogger(__name__)
class TextRecognizer(object):
    """Recognize text in cropped text-line images with a Paddle predictor."""

    def __init__(self, args):
        """Create the predictor (unless pdserving is used) and the decoder.

        Args:
            args: object exposing ``use_pdserving``, ``use_zero_copy_run``,
                ``rec_image_shape`` (comma-separated "C,H,W" string),
                ``rec_char_type``, ``rec_batch_num``, ``rec_algorithm``,
                ``max_text_length``, ``rec_char_dict_path`` and
                ``use_space_char``.
        """
        # NOTE(review): when use_pdserving is not False, no predictor or
        # tensors are created, so __call__ cannot run on this instance.
        if args.use_pdserving is False:
            self.predictor, self.input_tensor, self.output_tensors = \
                create_predictor(args)
        self.use_zero_copy_run = args.use_zero_copy_run
        self.rec_image_shape = [
            int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.text_len = args.max_text_length
        char_ops_params = {
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char,
            "max_text_length": args.max_text_length,
            'loss_type': 'ctc',
        }
        self.loss_type = 'ctc'
        self.char_ops = CharacterOps(char_ops_params)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize, normalize and right-pad one image.

        Args:
            img: image array indexed (height, width, channel); the channel
                count must equal the configured one.
            max_wh_ratio: largest width/height ratio in the current batch.

        Returns:
            A float32 array of shape (C, H, W) with values scaled to
            [-1, 1] in the resized region and zeros in the padding.
        """
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        wh_ratio = max(max_wh_ratio, imgW * 1.0 / imgH)
        if self.character_type == "ch":
            # Target width follows the widest image of the batch.
            imgW = int((32 * wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        # HWC -> CHW, then map [0, 255] to [-1, 1].
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Recognize every image in ``img_list``.

        Args:
            img_list: list of image arrays indexed (height, width, channel).

        Returns:
            ``(rec_res, predict_time)``. ``rec_res`` has one ``[text,
            score]`` entry per input image, in input order; images with no
            non-blank prediction keep ``['', 0.0]``. ``predict_time`` is
            the accumulated inference time in seconds.
        """
        img_num = len(img_list)
        # Aspect ratio (width / height) of every text line.
        width_list = [
            img.shape[1] / float(img.shape[0]) for img in img_list]
        # Batching images of similar width speeds up recognition.
        indices = np.argsort(np.array(width_list))
        # One independent placeholder per image; ``[x] * n`` would make all
        # entries the same list object.
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        predict_time = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_imgs = [
                img_list[indices[ino]]
                for ino in range(beg_img_no, end_img_no)
            ]
            max_wh_ratio = 0
            for img in batch_imgs:
                h, w = img.shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)
            norm_img_batch = np.concatenate(
                [self.resize_norm_img(img, max_wh_ratio)[np.newaxis, :]
                 for img in batch_imgs],
                axis=0)
            norm_img_batch = norm_img_batch.copy()

            starttime = time.time()
            if self.use_zero_copy_run:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.zero_copy_run()
            else:
                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
                self.predictor.run([norm_img_batch])
            # Output 0 holds the predicted indices and output 1 the scores
            # used for confidence; the first LoD level of each gives the
            # per-image row offsets.
            rec_idx_batch = self.output_tensors[0].copy_to_cpu()
            rec_idx_lod = self.output_tensors[0].lod()[0]
            predict_batch = self.output_tensors[1].copy_to_cpu()
            predict_lod = self.output_tensors[1].lod()[0]
            predict_time += time.time() - starttime

            for rno in range(len(rec_idx_lod) - 1):
                beg = rec_idx_lod[rno]
                end = rec_idx_lod[rno + 1]
                preds_text = self.char_ops.decode(rec_idx_batch[beg:end, 0])
                beg = predict_lod[rno]
                end = predict_lod[rno + 1]
                probs = predict_batch[beg:end, :]
                ind = np.argmax(probs, axis=1)
                # Rows whose best class is the last column are excluded
                # from the score.
                blank = probs.shape[1]
                valid_ind = np.where(ind != (blank - 1))[0]
                if len(valid_ind) == 0:
                    continue
                score = np.mean(probs[valid_ind, ind[valid_ind]])
                # Write back at the image's position in the original order.
                rec_res[indices[beg_img_no + rno]] = [preds_text, score]
        return rec_res, predict_time
def parse_args(argv=None):
    """Parse the recognizer's command-line options.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]`` as before.

    Returns:
        The ``argparse.Namespace`` holding every option.
    """

    def str2bool(v):
        """Treat "true", "t" and "1" (case-insensitive) as True."""
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    # Parameters for the prediction engine.
    parser.add_argument("--use_gpu", type=str2bool, default=False)
    # Parameters for the text recognizer.
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str, default='')
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=120)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    return parser.parse_args(argv)
def base64_to_image(base64_code):
    """Decode base64-encoded image data into a colour image matrix.

    NOTE(review): the original comment described the result as RGB;
    OpenCV documents colour decoding as BGR channel order - confirm.
    """
    raw_bytes = base64.b64decode(base64_code)
    byte_buffer = np.frombuffer(raw_bytes, np.uint8)
    return cv2.imdecode(byte_buffer, cv2.IMREAD_COLOR)
def main(args, image_str):
    """Recognize the text in one base64-encoded image.

    Args:
        args: parsed options passed through to ``TextRecognizer``.
        image_str: base64-encoded image data.

    Returns:
        The recognized text on success, the sentinel ``'img_str'`` when the
        image cannot be decoded, or the sentinel ``'text_recognizer'`` when
        recognition fails or yields no result.
    """
    try:
        img = base64_to_image(image_str)
    except Exception as e:
        print(e)
        return 'img_str'
    if img is None:
        # cv2.imdecode reports undecodable data by returning None instead
        # of raising, so this is an image-decoding failure, not a
        # recognizer failure.
        return 'img_str'
    img_list = [img]
    try:
        # NOTE(review): a new recognizer is built on every call.
        text_recognizer = TextRecognizer(args)
        rec_res, predict_time = text_recognizer(img_list)
    except Exception as e:
        print(e)
        return 'text_recognizer'
    if not rec_res:
        return 'text_recognizer'
    print("Predict:%s" % (rec_res[0]))
    print("Total predict time for %d images:%.3f" %
          (len(img_list), predict_time))
    return rec_res[0][0]
app = Flask('ocr')


@app.route('/ocr', methods=['POST'])  # POST endpoint at /ocr
def ocr():
    """Handle an OCR request and return a JSON-encoded status string.

    The JSON body must carry ``code`` (an MD5 hex digest used for
    verification) and ``image`` (base64-encoded image data). The reply has
    ``status`` 1 with the recognized text in ``data`` on success, or
    ``status`` 0 / -1 with an explanatory ``msg`` otherwise.
    """
    try:
        payload = request.json
    except Exception as e:
        print(e)
        return json.dumps({'status': 0, 'msg': 'json wrong!'})

    if not payload:
        return json.dumps({'status': 0, 'msg': 'json is null'})

    keys = payload.keys()
    if 'code' not in keys:
        return json.dumps({'status': 0, 'msg': 'Missing parameter'})
    if 'image' not in keys:
        return json.dumps({'status': -1, 'msg': 'image is null'})

    image_str = payload['image']
    if not image_str:
        return json.dumps({
            'status': -1,
            'msg': 'The parameter is empty or the parameter is not standard'
        })

    # NOTE(review): md5_str is not defined anywhere in the visible source;
    # it must be supplied at module level or this line raises NameError.
    code = hashlib.new('md5', md5_str.encode(encoding='UTF-8')).hexdigest()
    if code != payload['code']:
        return json.dumps({'status': -1, 'msg': 'Code verification failed'})

    rec_res = main(parse_args(), image_str)
    if rec_res == 'img_str':
        print('该base64字符串无法解析')
        return json.dumps({
            'status': -1,
            'msg': 'The base64 string cannot be parsed'
        })
    if rec_res == 'text_recognizer':
        print('图片识别异常')
        return json.dumps({
            'status': -1,
            'msg': 'The picture is not recognized'
        })
    return json.dumps({'status': 1, 'data': rec_res})
def application():
    """Poll the module-level ``url_bert`` URL every 10 seconds, forever.

    Each response body is printed. A failed request is printed and the
    loop continues, so one network error does not end the polling thread.
    """
    while True:
        try:
            # The timeout keeps a hung connection from blocking the loop.
            dd = requests.get(url_bert, timeout=10)
            print(dd.text)
        except requests.RequestException as e:
            print(e)
        time.sleep(10)
def start_app():
    """Run the Flask application on 192.168.0.128, port 52013."""
    app.run(host='192.168.0.128', port=52013)
if __name__ == '__main__':
    print('start app server!')
    # Module-level global read by application(). The query string uses
    # "&notice=1"; the previous "¬ice=1" was "&not" mangled into the
    # HTML entity for the not sign.
    url_bert = 'http://192.168.0.128:8080/HT/api/TaskSave?task=PROCPPS.OCR文字识别服务&notice=1&key=202101041434'
    # The web server and the polling loop each run in their own thread.
    Thread(target=start_app).start()
    Thread(target=application).start()
    # Printed as soon as both threads have been started.
    print('end app server!')