# 发票数据识别 (Invoice data recognition)

import cv2
import math
import numpy as np
import matplotlib.pyplot as plt
import skimage
from PIL import Image
from pytesseract import pytesseract
from skimage import data,color,morphology,feature
import argparse
#import cvHelper

# Source images (only img_ori2 is put into `imgs` and processed below).
img_ori1 = cv2.imread('TestData/taxi/IMG_3787.JPG')
img_ori2 = cv2.imread('TestData/taxi/IMG_3789.JPG')


imgs = [img_ori2]
resize_imgs = []
gray_resize_imgs = []

# Resize every image to a fixed width (aspect ratio preserved) and keep a
# grayscale copy of each resized image.
width = 300.0  # target width in pixels
for idx, im in enumerate(imgs):
    # cv2.imread returns None instead of raising when a file is missing or
    # unreadable; fail here with a clear message rather than with an
    # AttributeError on `im.shape`.
    if im is None:
        raise FileNotFoundError('could not read input image #%d' % idx)
    r = width / im.shape[1]  # scale factor
    dim = (int(width), int(im.shape[0] * r))
    img_resized = cv2.resize(im, dim, interpolation=cv2.INTER_AREA)
    resize_imgs.append(img_resized)
    gray = cv2.cvtColor(img_resized, cv2.COLOR_BGR2GRAY)
    gray_resize_imgs.append(gray)

# Display the resized image and two binarisations of it.
import pylab
cv2.namedWindow("ori img", cv2.WINDOW_AUTOSIZE)
cv2.moveWindow('ori img', 20, 24)
cv2.imshow('ori img', resize_imgs[0])
# cv2.imshow only paints a window while HighGUI processes events, which
# happens inside cv2.waitKey (matplotlib's show() does not do it). Wait for a
# key press so the window is actually rendered.
cv2.waitKey(0)

# Local (Gaussian-weighted) adaptive threshold, inverted: pixels darker than
# their 3x3 neighbourhood (minus C=5) become 255, everything else 0.
im_at_mean = cv2.adaptiveThreshold(gray_resize_imgs[0], 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 3, 5)
cv2.imshow("im_at_mean", im_at_mean)
cv2.waitKey(0)

# Global threshold on the red channel: red > 160 -> 255. Presumably chosen so
# the red stamp (bright in R) ends up white together with the paper while
# dark print stays black — confirm on more samples.
b, g, r = cv2.split(resize_imgs[0])
th, dst = cv2.threshold(r, 160, 255, cv2.THRESH_BINARY)
cv2.imshow("r_threshold", dst)
cv2.waitKey(0)

# Thicken the dark strokes: eroding the white regions of the red-channel
# threshold image grows its black regions.
kernel = np.ones((3, 3), np.uint8)
erosion = cv2.erode(dst, kernel, iterations=10)

# Keep only medium-sized blobs (300 < area < 2000) whose centroid lies in the
# right half of the image, drawn filled in white on a black canvas.
# findContours returns (image, contours, hierarchy) in OpenCV 3.x but
# (contours, hierarchy) in 4.x; taking the last two works with both.
contours, hierarchy = cv2.findContours(erosion, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)[-2:]
filterContours = []
for contour in contours:
    M = cv2.moments(contour)
    if M['m00'] == 0:
        continue  # degenerate contour: centroid undefined
    cx = int(M['m10'] / M['m00'])
    area = cv2.contourArea(contour)
    if 300 < area < 2000 and cx > erosion.shape[1] / 2:
        filterContours.append(contour)
drawing = np.zeros(erosion.shape, np.uint8)
cv2.drawContours(drawing, filterContours, -1, 255, -1)
cv2.imshow("erosion2", drawing)
# Needed for HighGUI to actually paint the window.
cv2.waitKey(0)

# Intermediate result: the stamp region still has to be overlaid before the
# non-printed text in the whole image can be filtered out.
# cv2.subtract saturates at 0 instead of wrapping like uint8 `-` (0 - 255 -> 1);
# after the threshold below the result is the same: drawing AND NOT erosion.
diff1 = cv2.subtract(drawing, erosion)
th, dst = cv2.threshold(diff1, 10, 255, cv2.THRESH_BINARY)

# Locate the stamp (seal) region.
## Method 1: extract it by colour.
hsv = cv2.cvtColor(resize_imgs[0], cv2.COLOR_BGR2HSV)
# Red sits on both sides of the hue wrap-around (8-bit OpenCV hue is 0..179).
# A single range with a negative lower bound (e.g. [-6, 14]) never matches the
# 174..179 side because no pixel has a negative hue, so OR two ranges.
mask = cv2.bitwise_or(
    cv2.inRange(hsv, np.array([0, 100, 100]), np.array([14, 255, 255])),
    cv2.inRange(hsv, np.array([174, 100, 100]), np.array([179, 255, 255])),
)

# Fill the stamp outline(s). findContours only distinguishes zero from
# non-zero pixels, so the binary colour mask is used directly.
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2:]
drawing2 = np.zeros(erosion.shape, np.uint8)
filterContours = [c for c in contours if cv2.contourArea(c) > 300]
cv2.drawContours(drawing2, filterContours, -1, 255, -1)

# Remove the filled stamp area from the candidate mask and binarise it.
# cv2.subtract saturates at 0 (uint8 `-` would wrap 0 - 255 to 1); after the
# threshold the result is dst AND NOT drawing2 either way.
mask = cv2.subtract(dst, drawing2)
retval, mask_fixed = cv2.threshold(mask, 50, 255, cv2.THRESH_BINARY)
kernel = np.ones((3, 3), np.uint8)
mask_fixed_erosion = cv2.erode(mask_fixed, kernel, iterations=2)
# Keep only the adaptive-threshold ink pixels that fall inside the mask.
drawing3 = cv2.bitwise_and(im_at_mean, im_at_mean, mask=mask_fixed_erosion)
# Invert (ink -> black) and erode once, which thickens the black strokes.
drawing3 = cv2.bitwise_not(drawing3)
drawing3_erosion = cv2.erode(drawing3, kernel, iterations=1)
cv2.imshow("drawing3_erosion", drawing3_erosion)
# Needed for HighGUI to actually paint the window.
cv2.waitKey(0)

# Row projection: flip the strokes back to white (255) on black, then total
# the pixel values of each row; rows crossing a text line get large sums.
drawing3_erosion = cv2.bitwise_not(drawing3_erosion)
horizontal_sum = drawing3_erosion.sum(axis=1)

def extract_peek_ranges_from_array(array_vals, minimun_val=1000, minimun_range=2):
    """Return the index ranges where a 1-D profile rises above a threshold.

    In this script it is applied to row sums (to find text lines) and to
    column sums of one line (to find character columns).

    A run starts at the first value strictly greater than ``minimun_val`` and
    ends at the first later value strictly less than it. A value exactly equal
    to the threshold neither starts a run nor ends an open one. A run that is
    still open when the input is exhausted is closed at the end of the input.

    Args:
        array_vals: iterable of numbers, e.g. a NumPy projection sum.
        minimun_val: threshold a value must exceed to belong to a peak.
        minimun_range: minimum run length (``end - start``) for a run to be
            kept.

    Returns:
        List of ``(start, end)`` tuples in input order; ``end`` is exclusive.
    """
    peek_ranges = []
    start_i = None
    length = 0
    for i, val in enumerate(array_vals):
        length = i + 1
        if start_i is None:
            if val > minimun_val:
                start_i = i
        elif val < minimun_val:
            if i - start_i >= minimun_range:
                peek_ranges.append((start_i, i))
            start_i = None
    # A run reaching the last element (e.g. ink touching the image border)
    # would otherwise be dropped silently; close it at the end of the input.
    if start_i is not None and length - start_i >= minimun_range:
        peek_ranges.append((start_i, length))
    return peek_ranges

# Detect text lines (row bands whose sum exceeds the default threshold) and
# outline each band across the full width of a copy of the resized image.
peek_ranges = extract_peek_ranges_from_array(horizontal_sum)

line_seg_adaptive_threshold = np.copy(resize_imgs[0])
img_width = line_seg_adaptive_threshold.shape[1]
for top, bottom in peek_ranges:
    cv2.rectangle(line_seg_adaptive_threshold, (0, top), (img_width, bottom), 255)

# Pick one detected text line and compute its column profile.
# NOTE(review): index 7 (the 8th band from the top) looks tuned for this
# particular sample image — confirm before using on other invoices.
TARGET_LINE_INDEX = 7
if len(peek_ranges) <= TARGET_LINE_INDEX:
    raise IndexError(
        'expected at least %d text lines, found %d'
        % (TARGET_LINE_INDEX + 1, len(peek_ranges)))
start, end = peek_ranges[TARGET_LINE_INDEX]
# Sum every column over the rows of that line.
v_sum = drawing3_erosion[start:end].sum(axis=0)
plt.plot(v_sum, range(v_sum.shape[0]))
plt.gca().invert_yaxis()
plt.show()

# Cut the chosen line into characters: sweep the column-sum threshold from
# 500 to 2490 in steps of 10 and collect every column run 4..18 px wide,
# skipping runs whose centre is within 5 px of one already collected.
tmpWords = []
y = start
h = end - start
for i in range(500, 2500, 10):
    vertical_peek_ranges = extract_peek_ranges_from_array(v_sum, minimun_val=i, minimun_range=3)
    print('循环', i)  # progress output: current threshold
    for x, x_end in vertical_peek_ranges:
        w = x_end - x
        if not 4 <= w <= 18:
            continue
        center_point = (x + w / 2, y + h / 2)
        # Already found at another threshold? (centres closer than 5 px)
        if any(math.hypot(word[4][0] - center_point[0], word[4][1] - center_point[1]) < 5
               for word in tmpWords):
            continue
        tmpWords.append((x, y, w, h, center_point))


# Crop each character candidate from the resized colour image and outline it
# in red (BGR (0, 0, 255)) on the line-segmentation preview.
toRecognizeWord = []
for x, y, w, h, _center in tmpWords:
    toRecognizeWord.append(resize_imgs[0][y:y + h, x:x + w])
    cv2.rectangle(line_seg_adaptive_threshold, (x, y), (x + w, y + h), (0, 0, 255))
cv2.imshow('line image', line_seg_adaptive_threshold)
# Without waitKey the window is never painted and the script exits
# immediately; wait for a key press, then close all HighGUI windows.
cv2.waitKey(0)
cv2.destroyAllWindows()

# 你可能感兴趣的:(发票数据识别)  (You may also be interested in: invoice data recognition)