digits_ann.py:
"""MNIST loading helpers and an OpenCV ANN_MLP digit classifier."""
import gzip
import sys

import cv2
import numpy as np

try:
    import cPickle  # Python 2
except ImportError:  # Python 3 has no separate cPickle module
    import pickle as cPickle


def load_data():
    """Load the three data splits pickled in ``mnist.pkl.gz``.

    The archive is read from the current working directory.

    Returns:
        The tuple ``(training_data, classification_data, test_data)`` exactly
        as stored in the pickle.
    """
    # NOTE: unpickling can execute arbitrary code -- only load a trusted file.
    # The ``with`` block closes the file even when unpickling raises.
    with gzip.open('mnist.pkl.gz', 'rb') as mnist:
        if sys.version_info[0] >= 3:
            # Presumably the archive was written by Python 2; latin1 lets
            # Python 3 decode the byte strings inside such a pickle.
            training_data, classification_data, test_data = cPickle.load(
                mnist, encoding='latin1')
        else:
            training_data, classification_data, test_data = cPickle.load(mnist)
    return training_data, classification_data, test_data


def wrap_data():
    """Reshape the raw splits into lists of ``(input, label)`` pairs.

    Every input is reshaped to a ``(784, 1)`` column. Training labels are
    converted with ``vectorized_result``; validation and test labels are kept
    as stored.

    Returns:
        ``(training_data, validation_data, test_data)``, each a list so that
        it can be indexed and iterated more than once.
    """
    tr_d, va_d, te_d = load_data()
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    # list(...) keeps the result indexable on Python 3, where zip is lazy.
    training_data = list(zip(training_inputs, training_results))
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = list(zip(validation_inputs, va_d[1]))
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = list(zip(test_inputs, te_d[1]))
    return training_data, validation_data, test_data


def vectorized_result(j):
    """Return a ``(10, 1)`` column of zeros with a 1.0 at index ``j``."""
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e


def create_ANN(hidden=20):
    """Create an untrained ``cv2.ml`` multilayer perceptron.

    Args:
        hidden: Number of nodes in the single hidden layer.

    Returns:
        An ANN_MLP with layer sizes 784 / ``hidden`` / 10, RPROP training,
        symmetric sigmoid activation, and a termination criterion of
        20 iterations or an epsilon of 1.
    """
    ann = cv2.ml.ANN_MLP_create()
    ann.setLayerSizes(np.array([784, hidden, 10]))
    ann.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
    ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
    ann.setTermCriteria(
        (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 1))
    return ann


def train(ann, samples=10000, epochs=1):
    """Train ``ann`` one sample at a time on the wrapped training split.

    Args:
        ann: The network to train (see ``create_ANN``).
        samples: Maximum number of training pairs used per epoch.
        epochs: Number of passes over those pairs.

    Returns:
        ``(ann, test)`` where ``test`` is the wrapped test split.
    """
    tr, val, test = wrap_data()
    # Per-sample calls must not rescale inputs/outputs from a single sample.
    no_scale = cv2.ml.ANN_MLP_NO_INPUT_SCALE | cv2.ml.ANN_MLP_NO_OUTPUT_SCALE
    for x in range(epochs):
        for counter, (data, digit) in enumerate(tr):
            # ``>=`` so that exactly ``samples`` pairs are used, not one more.
            if counter >= samples:
                break
            if counter % 1000 == 0:
                print("Epoch %d: Trained %d/%d" % (x, counter, samples))
            train_data = cv2.ml.TrainData_create(
                np.array([data.ravel()], dtype=np.float32),
                cv2.ml.ROW_SAMPLE,
                np.array([digit.ravel()], dtype=np.float32))
            flags = no_scale
            if ann.isTrained():
                # Without UPDATE_WEIGHTS every call recomputes the weights
                # from scratch, discarding everything learned so far.
                flags |= cv2.ml.ANN_MLP_UPDATE_WEIGHTS
            ann.train(train_data, flags)
        print("Epoch %d complete" % x)
    return ann, test


def test(ann, test_data):
    """Show the first test image, wait for a key, then print its prediction."""
    sample = np.array(
        test_data[0][0].ravel(), dtype=np.float32).reshape(28, 28)
    cv2.imshow("sample", sample)
    cv2.waitKey()
    print(ann.predict(np.array([test_data[0][0].ravel()], dtype=np.float32)))


def predict(ann, sample):
    """Run ``ann`` on a 2-D image, resizing it to 28x28 first when needed.

    Args:
        ann: A trained network.
        sample: A 2-D array; unpacking its shape raises ``ValueError`` for
            any other number of dimensions.

    Returns:
        Whatever ``ann.predict`` returns for the flattened sample.
    """
    resized = sample.copy()
    rows, cols = resized.shape
    # An empty image cannot be resized; it is passed through unchanged.
    if (rows != 28 or cols != 28) and rows * cols > 0:
        resized = cv2.resize(resized, (28, 28), interpolation=cv2.INTER_CUBIC)
    return ann.predict(np.array([resized.ravel()], dtype=np.float32))
main.py:
"""Find digit-like contours in ``./c.png`` and label them with the ANN."""
import cv2
import numpy as np

import digits_ann as ANN


def inside(r1, r2):
    """Return True when rectangle ``r1`` lies strictly inside ``r2``.

    Both rectangles are ``(x, y, w, h)`` tuples.
    """
    x1, y1, w1, h1 = r1
    x2, y2, w2, h2 = r2
    return (x1 > x2) and (y1 > y2) and (x1 + w1 < x2 + w2) and (y1 + h1 < y2 + h2)


def wrap_digit(rect):
    """Expand ``rect`` to a square around its centre and add padding.

    Args:
        rect: An ``(x, y, w, h)`` tuple of integers.

    Returns:
        An ``(x, y, w, h)`` tuple. The origin may be negative when the input
        rectangle is close to the top or left image border.
    """
    x, y, w, h = rect
    padding = 5
    # Floor division keeps the coordinates integral on Python 2 and 3.
    hcenter = x + w // 2
    vcenter = y + h // 2
    if h > w:
        w = h
        x = hcenter - (w // 2)
    else:
        h = w
        y = vcenter - (h // 2)
    # NOTE(review): the origin moves by ``padding`` but the size grows by only
    # ``padding``, so the right/bottom edges get no margin -- confirm intent.
    return (x - padding, y - padding, w + padding, h + padding)


ann, test_data = ANN.train(ANN.create_ANN(56), 20000)
font = cv2.FONT_HERSHEY_SIMPLEX

path = "./c.png"
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img is None:
    # cv2.imread signals an unreadable file by returning None.
    raise IOError("could not read image: %s" % path)

bw = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
bw = cv2.GaussianBlur(bw, (7, 7), 0)
ret, thbw = cv2.threshold(bw, 127, 255, cv2.THRESH_BINARY_INV)
thbw = cv2.erode(thbw, np.ones((2, 2), np.uint8), iterations=2)

# findContours returns (image, contours, hierarchy) on OpenCV 3 but
# (contours, hierarchy) on 2.x/4.x; [-2] selects the contours in every case.
cntrs = cv2.findContours(
    thbw.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

# A contour with exactly this area is not kept as a digit candidate.
b = (img.shape[0] - 3) * (img.shape[1] - 3)

rectangles = []
for c in cntrs:
    r = cv2.boundingRect(c)
    a = cv2.contourArea(c)
    # Skip rectangles nested inside one that was already accepted.
    if any(inside(r, q) for q in rectangles):
        continue
    if a != b:
        rectangles.append(r)

for r in rectangles:
    x, y, w, h = wrap_digit(r)
    cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
    # Clamp the origin: a negative start index would wrap around in NumPy
    # slicing and yield a wrong or empty region for digits near the border.
    roi = thbw[max(y, 0):y + h, max(x, 0):x + w]
    try:
        digit_class = int(ANN.predict(ann, roi.copy())[0])
    except Exception:
        # Best effort: skip regions that cannot be classified, without
        # swallowing KeyboardInterrupt/SystemExit as a bare except would.
        continue
    cv2.putText(img, "%d" % digit_class, (x, y - 1), font, 1, (0, 255, 0))

cv2.imshow("thbw", thbw)
cv2.imshow("contours", img)
cv2.imwrite("results.jpg", img)
cv2.waitKey()
Results:
Comparison of the results on the pictures:
Conclusion:
The comparison shown above reveals that this approach is insufficient.
Reference: 《Learning OpenCV 3 Computer Vision with Python》
"mnist.pkl.gz" was downloaded from http://www.cnblogs.com/xueliangliu/archive/2013/04/03/2997437.html