OpenCV3 计算机视觉 Python语言实现 用人工神经网络进行手写数字识别

运行本案例的代码时可能会出现以下错误,现将错误信息及解决方法列出如下:

错误代码:

1.TypeError: integer argument expected, got float

解决方式:Python 3 中 '/'(除法)的结果总是浮点数,而 wrap_digit() 返回的坐标会传给 cv2.rectangle() 并用于数组切片,这些地方都要求整数。因此应改用整除 '//',或在返回值前加 int() 转换,如 return (int(x-padding), int(y-padding), int(w+padding), int(h+padding))

2.OpenCV(3.4.1) Error: Assertion failed ((type == 5 || type == 6) && inputs.cols == layer_sizes[0]) in cv::ml::ANN_MLPImpl::predict, file D:\Build\OpenCV\opencv-3.4.1\modules\ml\src\ann_mlp.cpp, line 411

解决方式:此错误主要出现在处理测试图片时。程序查找数字轮廓时,如果数字边界识别不正确,或外接矩形框(加上边距后)超出了图片边界,截取到的数字区域(ROI)就会为空或尺寸不对,导致送入网络的输入列数与输入层的 784 个节点不一致,从而触发该断言。例如:

(原文此处附有示例图片,分别展示了数字边界识别不正确和矩形框超出图片边界的情况。)对测试图片进行相应处理(例如让数字与图片边缘保持一定距离)后,即可解决该错误。

程序代码:digits_ann.py

#coding=utf-8
import cv2
import pickle
import numpy as np
import gzip

"""OpenCV ANN Handwritten digit recognition example

Wraps OpenCV's own ANN by automating the loading of data and supplying default parameters,
such as a single hidden layer of 20 neurons, 10000 samples and 1 training epoch.

The load data code is taken from http://neuralnetworksanddeeplearning.com/chap1.html
by Michael Nielsen
"""

def load_data(path=r'G:\Python_work\ANN\data/mnist.pkl.gz'):
  """Load the gzipped MNIST pickle used by Michael Nielsen's book.

  Args:
    path: location of ``mnist.pkl.gz``. Defaults to the original author's
      local path, so existing ``load_data()`` calls keep working.

  Returns:
    ``(training_data, validation_data, test_data)`` exactly as stored in the
    pickle.
  """
  # The context manager closes the file even if unpickling fails.
  # encoding="bytes" tells pickle to keep Python 2 8-bit strings found in the
  # file as bytes instead of decoding them.
  # NOTE: pickle can execute arbitrary code -- only load trusted files.
  with gzip.open(path, 'rb') as mnist:
    training_data, validation_data, test_data = pickle.load(mnist, encoding="bytes")
  return (training_data, validation_data, test_data)

def wrap_data():
  """Turn the raw MNIST splits into lists of ``(input, target)`` pairs.

  Every input is reshaped to a (784, 1) column. Training targets are one-hot
  (10, 1) vectors from vectorized_result(); validation and test targets are
  the raw labels.

  Returns:
    ``(training_data, validation_data, test_data)``, each a list of pairs.
  """
  tr_d, va_d, te_d = load_data()
  # Build lists rather than returning bare zip() objects. In Python 3, zip is
  # a one-shot iterator: a second training epoch would see no data at all,
  # and indexing it (test_data[0]) raises TypeError.
  training_data = [(np.reshape(x, (784, 1)), vectorized_result(y))
                   for x, y in zip(tr_d[0], tr_d[1])]
  validation_data = [(np.reshape(x, (784, 1)), y) for x, y in zip(va_d[0], va_d[1])]
  test_data = [(np.reshape(x, (784, 1)), y) for x, y in zip(te_d[0], te_d[1])]
  return (training_data, validation_data, test_data)

def vectorized_result(j):
  """One-hot encode digit ``j`` as a (10, 1) float64 column vector.

  The entry at index ``j`` is 1.0 and every other entry is 0.0; this is the
  training target format fed to the network.
  """
  one_hot = np.zeros(shape=(10, 1), dtype=np.float64)
  one_hot[j] = 1.0
  return one_hot

def create_ANN(hidden = 20):
  """Create an untrained OpenCV multi-layer perceptron for MNIST digits.

  Topology is 784 inputs -> ``hidden`` neurons -> 10 outputs. Training uses
  RPROP with the symmetric sigmoid activation and stops after 100 iterations
  or when the epsilon criterion (1) is met.
  """
  topology = np.array([784, hidden, 10])
  stop_when = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 100, 1)

  net = cv2.ml.ANN_MLP_create()
  net.setLayerSizes(topology)
  net.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
  net.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
  net.setTermCriteria(stop_when)
  return net

def train(ann, samples = 10000, epochs = 1):
  """Train ``ann`` online on the first ``samples`` MNIST training images.

  Images are presented one at a time, and the whole subset is repeated
  ``epochs`` times.

  Args:
    ann: network created by create_ANN().
    samples: number of training images used per epoch.
    epochs: number of passes over those images.

  Returns:
    ``(ann, test_data)``: the trained network and the test split from
    wrap_data().
  """
  tr, _, test_data = wrap_data()
  # Materialise once so every epoch iterates the same data; a Python 3 zip
  # iterator would already be exhausted after the first epoch.
  tr = list(tr)[:samples]

  # The first train() call initialises the weights. Every later call must
  # pass UPDATE_WEIGHTS; otherwise OpenCV re-initialises the network on each
  # call and effectively only the last sample is learned.
  # NOTE(review): input normalisation is computed on that first call only,
  # i.e. from a single sample; consider cv2.ml.ANN_MLP_NO_INPUT_SCALE since
  # the MNIST inputs are presumably already in [0, 1] -- confirm.
  flags = 0
  for epoch in range(epochs):
    for counter, (data, digit) in enumerate(tr):
      if counter % 1000 == 0:
        print("Epoch %d: Trained %d/%d" % (epoch, counter, samples))
      one_sample = cv2.ml.TrainData_create(
          np.array([data.ravel()], dtype=np.float32),
          cv2.ml.ROW_SAMPLE,
          np.array([digit.ravel()], dtype=np.float32))
      ann.train(one_sample, flags)
      flags = cv2.ml.ANN_MLP_UPDATE_WEIGHTS
    print("Epoch %d complete" % epoch)
  return ann, test_data
  
def test(ann, test_data):
  """Display the first test image and print the network's prediction for it.

  ``test_data`` may be a list or any iterable of ``(input, label)`` pairs
  whose input holds 784 values (28x28), as returned by train() /
  wrap_data(). Blocks until a key is pressed in the image window.

  Raises:
    ValueError: if ``test_data`` is empty.
  """
  # next(iter(...)) works for lists and for one-shot zip iterators alike;
  # plain indexing fails on the latter in Python 3.
  try:
    first_input, _label = next(iter(test_data))
  except StopIteration:
    raise ValueError("test_data is empty") from None
  flat = np.array(first_input.ravel(), dtype=np.float32)
  cv2.imshow("sample", flat.reshape(28, 28))
  cv2.waitKey()
  # The network expects one sample per row: shape (1, 784).
  print(ann.predict(flat.reshape(1, -1)))

def predict(ann, sample):
  """Classify a single 2-D digit image with ``ann``.

  Args:
    ann: trained network; its ``predict`` method is called with a (1, 784)
      float32 array.
    sample: 2-D image. It is resized to 28x28 if it has any other shape.

  Returns:
    Whatever ``ann.predict`` returns for the flattened image (callers in this
    project use element ``[0]`` as the predicted class).

  Raises:
    ValueError: if ``sample`` is empty (e.g. an ROI cut outside the image).
      Without this check the empty array reaches OpenCV and fails with the
      opaque assertion ``inputs.cols == layer_sizes[0]``.
  """
  rows, cols = sample.shape
  if rows * cols == 0:
    raise ValueError("cannot classify an empty image (shape %r)" % (sample.shape,))
  resized = sample
  if (rows, cols) != (28, 28):
    # cv2.resize returns a new array, so the caller's image is never modified.
    resized = cv2.resize(sample, (28, 28), interpolation = cv2.INTER_LINEAR)
  # NOTE(review): the training images are presumably scaled to [0, 1], while
  # a thresholded uint8 ROI holds 0/255 and is passed unscaled -- confirm
  # whether it should be divided by 255.
  return ann.predict(np.array([resized.ravel()], dtype=np.float32))

"""
usage:
ann, test_data = train(create_ANN())
test(ann, test_data)
"""

程序代码:digits_image_process.py

#coding=utf-8
import cv2
import numpy as np
import digits_ann as ANN

def inside(r1, r2):
  """Return True when rectangle ``r1`` lies strictly inside rectangle ``r2``.

  Both rectangles are ``(x, y, w, h)`` tuples, as produced by
  cv2.boundingRect in this script. Shared or touching edges do not count as
  inside.
  """
  ix, iy, iw, ih = r1
  ox, oy, ow, oh = r2
  starts_after = ox < ix and oy < iy
  ends_before = ix + iw < ox + ow and iy + ih < oy + oh
  return bool(starts_after and ends_before)

def wrap_digit(rect, padding=5):
  """Expand a digit's bounding box into a padded square centred on it.

  The shorter side is grown to match the longer one around the box centre.
  The square is then padded by ``padding`` pixels on every side.

  Args:
    rect: ``(x, y, w, h)`` bounding box.
    padding: margin added on each side, in pixels.

  Returns:
    ``(x, y, w, h)`` as ints. ``x``/``y`` may be negative, or the box may
    extend past the image, when the digit is near an edge; callers must clamp
    before slicing.
  """
  x, y, w, h = rect
  hcenter = x + w / 2
  vcenter = y + h / 2
  if h > w:
    w = h
    x = hcenter - (w / 2)
  else:
    h = w
    y = vcenter - (h / 2)
  # Shifting the origin by `padding` but growing the size by only `padding`
  # would pad the left/top edges only. Grow by 2*padding so the digit stays
  # centred. int() is required: '/' yields floats in Python 3, and
  # cv2.rectangle / slicing need ints.
  return (int(x - padding), int(y - padding), int(w + 2 * padding), int(h + 2 * padding))

# Load and binarise the test image first, so a bad path fails immediately
# instead of after several minutes of network training.
# path = "./images/MNISTsamples.png"
path = r"G:\Python_work\ANN\images/digits.png"
# IMREAD_COLOR always yields an 8-bit 3-channel BGR image, which is what
# COLOR_BGR2GRAY and the green overlays below expect. IMREAD_UNCHANGED would
# pass 1-channel or 16-bit files through as-is.
img = cv2.imread(path, cv2.IMREAD_COLOR)
if img is None:
  raise SystemExit("Cannot read image: %s" % path)
bw = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
bw = cv2.GaussianBlur(bw, (7, 7), 0)
# Inverted threshold: pixels darker than 127 (the ink) become 255.
_, thbw = cv2.threshold(bw, 127, 255, cv2.THRESH_BINARY_INV)
thbw = cv2.erode(thbw, np.ones((2, 2), np.uint8), iterations = 2)
# OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
# (contours, hierarchy); index -2 selects the contours for both.
cntrs = cv2.findContours(thbw.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

# ann, test_data = ANN.train(ANN.create_ANN(56), 50000, 5)
ann, test_data = ANN.train(ANN.create_ANN(58), 50000, 5)
font = cv2.FONT_HERSHEY_SIMPLEX

img_h, img_w = thbw.shape[:2]
# A contour with exactly this area is treated as the image border and ignored
# (same heuristic as before, now computed once instead of per contour).
frame_area = (img.shape[0] - 3) * (img.shape[1] - 3)
candidates = [cv2.boundingRect(c) for c in cntrs if cv2.contourArea(c) != frame_area]
# Keep only outermost boxes: drop any box strictly inside another one (e.g.
# the holes of 0/6/8/9). Testing against all candidates, not only those
# accepted so far, makes the result independent of contour order.
rectangles = [r for r in candidates if not any(inside(r, q) for q in candidates)]

for r in rectangles:
  x, y, w, h = wrap_digit(r)
  # Clamp the padded box to the image. A negative slice start would wrap to
  # the far edge and give an empty ROI, which then fails inside the ANN with
  # the assertion "inputs.cols == layer_sizes[0]".
  x0, y0 = max(x, 0), max(y, 0)
  x1, y1 = min(x + w, img_w), min(y + h, img_h)
  cv2.rectangle(img, (x0, y0), (x1, y1), (0, 255, 0), 2)
  roi = thbw[y0:y1, x0:x1]
  if roi.size == 0:
    continue
  try:
    digit_class = int(ANN.predict(ann, roi)[0])
  except (cv2.error, ValueError) as err:
    # Report and skip this box instead of silently swallowing every error.
    print("Skipping box %r: %s" % ((x0, y0, x1 - x0, y1 - y0), err))
    continue
  cv2.putText(img, "%d" % digit_class, (x0, y0 - 1), font, 1, (0, 255, 0))

cv2.imshow("thbw", thbw)
cv2.imshow("contours", img)
cv2.imwrite("sample.jpg", img)
cv2.waitKey()
cv2.destroyAllWindows()
 

参考文档:py4CV例子3Mnist识别和ANN 链接:https://www.cnblogs.com/jsxyhelu/p/8611935.html

 

 

你可能感兴趣的:(OpenCV3 计算机视觉 Python语言实现 用人工神经网络进行手写数字识别)