运行本案例代码时可能遇到以下错误,错误信息及解决方法如下:
错误代码:
1.TypeError: integer argument expected, got float
解决方式:cv2.rectangle() 等 OpenCV 绘图函数的坐标参数要求为整数(NumPy 数组切片的下标同样要求整数),而 Python 3 中 '/' 除法的结果总是浮点数,因此 wrap_digit() 返回的是浮点坐标。可改用整除运算 '//',或用 int() 转换返回值,如 return (int(x-padding), int(y-padding), int(w+padding), int(h+padding))
2.OpenCV(3.4.1) Error: Assertion failed ((type == 5 || type == 6) && inputs.cols == layer_sizes[0]) in cv::ml::ANN_MLPImpl::predict, file D:\Build\OpenCV\opencv-3.4.1\modules\ml\src\ann_mlp.cpp, line 411
解决方式:该断言要求送入 predict() 的数据为浮点型矩阵(type 5/6,即 CV_32F/CV_64F),且列数等于输入层大小(784)。此错误通常源于对测试图片的处理:查找数字边界不正确,或边界矩形框超出图像边界,导致截取的 ROI 为空,送入网络的数据列数为 0。完整程序代码如下:
程序代码:digits_ann.py
#coding=utf-8
import cv2
import pickle
import numpy as np
import gzip
"""OpenCV ANN Handwritten digit recognition example
Wraps OpenCV's own ANN by automating the loading of data and supplying default parameters,
such as a single hidden layer of 20 neurons, 10000 training samples and 1 training epoch.
The load data code is taken from http://neuralnetworksanddeeplearning.com/chap1.html
by Michael Nielsen
"""
def load_data(path = r'G:\Python_work\ANN\data/mnist.pkl.gz'):
    """Load the gzipped, pickled MNIST dataset.

    :param path: location of ``mnist.pkl.gz``; defaults to the original
        hard-coded location so existing callers are unaffected.
    :returns: the 3-tuple stored in the pickle, in file order
        (training, validation, test), unchanged.
    """
    # `with` guarantees the file is closed even if unpickling fails.
    # encoding="bytes" is needed to read this file under Python 3 (presumably it
    # was pickled under Python 2).
    # SECURITY NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with gzip.open(path, 'rb') as mnist:
        training_data, classification_data, test_data = pickle.load(mnist, encoding="bytes")
    return (training_data, classification_data, test_data)
def wrap_data():
    """Load MNIST and reshape it into the (input, target) pairs used by train()/test().

    :returns: (training_data, validation_data, test_data), each a list of pairs:
        - training_data: (784x1 input column, 10x1 one-hot target)
        - validation_data / test_data: (784x1 input column, raw label)

    Lists are returned instead of Python 3 ``zip`` iterators. The lists can be
    iterated again on every training epoch and indexed (e.g. ``test_data[0][0]``);
    a zip object supports neither.
    """
    tr_d, va_d, te_d = load_data()
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = list(zip(training_inputs, training_results))
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = list(zip(validation_inputs, va_d[1]))
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = list(zip(test_inputs, te_d[1]))
    return (training_data, validation_data, test_data)
def vectorized_result(j):
    """One-hot encode digit label *j* as a (10, 1) float column vector.

    Every entry is 0.0 except row *j*, which is 1.0.
    """
    one_hot = np.zeros((10, 1))
    one_hot[j, 0] = 1.0
    return one_hot
def create_ANN(hidden = 20):
    """Build an untrained cv2.ml.ANN_MLP with layers 784 -> hidden -> 10.

    The network trains with RPROP and uses the symmetric sigmoid activation.
    Training stops after 100 iterations or once the error change drops below 1,
    whichever comes first.

    :param hidden: number of neurons in the single hidden layer.
    :returns: the configured cv2.ml.ANN_MLP instance.
    """
    layer_sizes = np.array([784, hidden, 10])
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 100, 1)
    network = cv2.ml.ANN_MLP_create()
    network.setLayerSizes(layer_sizes)
    network.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
    network.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
    network.setTermCriteria(criteria)
    return network
def train(ann, samples = 10000, epochs = 1):
    """Train *ann* on the first *samples* MNIST training pairs for *epochs* passes.

    :param ann: network returned by create_ANN().
    :param samples: number of training pairs to use.
    :param epochs: number of training passes over those pairs.
    :returns: (ann, test_data), where test_data is the test split from wrap_data().
    :raises ValueError: if no training samples are selected.
    """
    tr, _, test = wrap_data()
    # list() accepts both a list and a one-shot zip iterator from wrap_data().
    pairs = list(tr)[:samples]
    if not pairs:
        raise ValueError("no training samples selected (samples=%d)" % samples)
    # One row per sample: 784 input values and 10 one-hot target values.
    inputs = np.array([data.ravel() for data, _ in pairs], dtype=np.float32)
    targets = np.array([digit.ravel() for _, digit in pairs], dtype=np.float32)
    train_data = cv2.ml.TrainData_create(inputs, cv2.ml.ROW_SAMPLE, targets)
    for x in range(epochs):
        print ("Epoch %d: Training on %d samples" % (x, len(pairs)))
        # ANN_MLP.train() re-initialises the weights unless UPDATE_WEIGHTS is set.
        # Training one sample per call would therefore keep only the last sample's
        # fit. So train on the whole batch: fresh weights in the first epoch, then
        # continue from the learned weights in later epochs.
        flags = 0 if x == 0 else cv2.ml.ANN_MLP_UPDATE_WEIGHTS
        ann.train(train_data, flags)
        print ("Epoch %d complete" % x)
    return ann, test
def test(ann, test_data):
    """Display the first test image and print the network's prediction for it.

    :param ann: a trained cv2.ml.ANN_MLP.
    :param test_data: iterable of (784x1 input, label) pairs, e.g. from train().
    :raises ValueError: if *test_data* is empty.
    """
    # next(iter(...)) works for lists and for Python 3 zip objects alike;
    # zip objects cannot be indexed with test_data[0].
    first = next(iter(test_data), None)
    if first is None:
        raise ValueError("test_data is empty")
    flat = np.array(first[0].ravel(), dtype=np.float32)
    cv2.imshow("sample", flat.reshape(28, 28))
    cv2.waitKey()
    # predict expects one sample per row: shape (1, 784).
    print (ann.predict(flat.reshape(1, -1)))
def predict(ann, sample):
    """Classify a single digit image with *ann*.

    :param ann: trained network; its ``predict`` receives a (1, 784) float32 array.
    :param sample: non-empty 2-D single-channel image of any size; it is resized
        to 28x28 when needed. uint8 images (0-255, e.g. a thresholded ROI) are
        scaled to 0-1. The MNIST pickle used for training stores pixel
        intensities as floats in [0, 1].
    :returns: whatever ``ann.predict`` returns (for cv2.ml.ANN_MLP: (retval, outputs)).
    :raises ValueError: if *sample* has zero rows or columns.
    """
    resized = sample.copy()
    rows, cols = resized.shape
    if rows * cols == 0:
        # An empty ROI (e.g. a box outside the image) cannot be resized to 28x28.
        # Passed on, it would hit ANN_MLP's "inputs.cols == layer_sizes[0]" assertion.
        raise ValueError("cannot classify an empty %dx%d sample" % (rows, cols))
    if rows != 28 or cols != 28:
        resized = cv2.resize(resized, (28, 28), interpolation = cv2.INTER_LINEAR)
    features = np.array([resized.ravel()], dtype=np.float32)
    if sample.dtype == np.uint8:
        # Match the 0-1 intensity range of the training data.
        features /= 255.0
    return ann.predict(features)
"""
usage:
ann, test_data = train(create_ANN())
test(ann, test_data)
"""
程序代码:digits_image_process.py
#coding=utf-8
import cv2
import numpy as np
import digits_ann as ANN
def inside(r1, r2):
    """Tell whether rectangle *r1* lies strictly within rectangle *r2*.

    Both arguments are (x, y, w, h) tuples. Touching or shared edges do not
    count as inside, so a rectangle is never inside itself.
    """
    left1, top1, width1, height1 = r1
    left2, top2, width2, height2 = r2
    return (left1 > left2
            and top1 > top2
            and left1 + width1 < left2 + width2
            and top1 + height1 < top2 + height2)
def wrap_digit(rect, padding = 5, bounds = None):
    """Expand a tight digit bounding box into a padded square around the same centre.

    :param rect: (x, y, w, h) box, e.g. from cv2.boundingRect.
    :param padding: margin in pixels added on *every* side of the square.
    :param bounds: optional image shape (height, width[, channels]); when given,
        the box is also clipped so it does not extend past the image.
    :returns: (x, y, w, h) as ints, with x >= 0 and y >= 0.
    """
    x, y, w, h = rect
    hcenter = x + w / 2
    vcenter = y + h / 2
    # Grow the shorter side to match the longer one, keeping the centre fixed.
    side = max(w, h)
    x = int(hcenter - side / 2 - padding)
    y = int(vcenter - side / 2 - padding)
    # Padding on both sides: the right/bottom margin needs its own `padding`.
    w = h = int(side + 2 * padding)
    # A negative start index would wrap around when the box is used to slice a
    # numpy image, so clip the box at the top/left image edge.
    if x < 0:
        w += x
        x = 0
    if y < 0:
        h += y
        y = 0
    if bounds is not None:
        img_h, img_w = bounds[:2]
        w = min(w, img_w - x)
        h = min(h, img_h - y)
    return (x, y, max(w, 0), max(h, 0))
# Script body: train the network, then locate, box and classify digits in an image.
# ann, test_data = ANN.train(ANN.create_ANN(56), 50000, 5)
ann, test_data = ANN.train(ANN.create_ANN(58), 50000, 5)
font = cv2.FONT_HERSHEY_SIMPLEX
# path = "./images/MNISTsamples.png"
path = r"G:\Python_work\ANN\images/digits.png"
# IMREAD_COLOR always yields a 3-channel BGR image, which COLOR_BGR2GRAY and the
# green annotations below require. IMREAD_UNCHANGED gives 4 channels for a PNG
# with alpha and 1 for a grayscale file; both break cvtColor.
img = cv2.imread(path, cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread reports an unreadable/missing file by returning None, not raising.
    raise OSError("could not read image: %s" % path)
bw = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
bw = cv2.GaussianBlur(bw, (7, 7), 0)
# Invert so dark digits on a light background become white (255) on black,
# then erode to thin the white strokes.
ret, thbw = cv2.threshold(bw, 127, 255, cv2.THRESH_BINARY_INV)
thbw = cv2.erode(thbw, np.ones((2, 2), np.uint8), iterations = 2)
# OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
# (contours, hierarchy); the last two items mean the same thing in both.
cntrs, hier = cv2.findContours(thbw.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
# Contours with exactly this area are skipped. Presumably this targets a contour
# running around the whole image, 3 px in from the edges -- TODO confirm.
border_area = (img.shape[0] - 3) * (img.shape[1] - 3)
candidates = [cv2.boundingRect(c) for c in cntrs if not cv2.contourArea(c) == border_area]
# Drop boxes nested inside any other box (e.g. the holes of 0, 6, 8, 9).
# Comparing against every candidate, not just those accepted so far, makes the
# result independent of contour order.
rectangles = [r for r in candidates if not any(inside(r, q) for q in candidates)]
img_h, img_w = thbw.shape[:2]
for r in rectangles:
    x, y, w, h = wrap_digit(r)
    cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
    # Clip the ROI to the image. A negative start index would wrap around in
    # numpy slicing and give a wrong or empty ROI; an empty one reaches the
    # network as a 0-column input.
    roi = thbw[max(y, 0):min(y + h, img_h), max(x, 0):min(x + w, img_w)]
    try:
        digit_class = int(ANN.predict(ann, roi.copy())[0])
    except Exception as err:
        # Best effort: report and skip a box that cannot be classified rather than
        # aborting. Unlike a bare except, this lets KeyboardInterrupt through.
        print("skipping box %s: %s" % ((x, y, w, h), err))
        continue
    cv2.putText(img, "%d" % digit_class, (x, y - 1), font, 1, (0, 255, 0))
cv2.imshow("thbw", thbw)
cv2.imshow("contours", img)
cv2.imwrite("sample.jpg", img)
cv2.waitKey()
参考文档:py4CV例子3Mnist识别和ANN 链接:https://www.cnblogs.com/jsxyhelu/p/8611935.html