作为大三的我,前一段时间搞了手写数字识别,什么支持向量机啊,人工神经网络啊,knn啊,都玩过了,但仅仅是停留在人家公开的训练数据集上,而拿来测试用的图片也是人家的,比如mnist上的图片是这样的。
而我们拍的照片却是这样的
这样就带来了一个问题,测试的结果就不对了,用这样的图片去测试训练好的模型,可能有时候精度连10%都达不到。
这样,有必要将MNIST的生成过程学习一遍。
MNIST数据集是一个手写数字的集合,包含了60000个训练集和10000个测试集。每一个数字都是20x20的,包含在28x28的图片里面。这对于我们预处理是非常重要的。
#我们先预处理图片
i = 0
for no in [8,0,4,3]:
gray = cv2.imread("own_"+str(no)+".png", cv2.IMREAD_GRAYSCALE)
gray = cv2.resize(255-gray,(28,28))
cv2.imwrite("preprocessing/proImage_"+str(no)+".png",gray)
得到的图片是这样的:
这相对于原始图片已经好多了,但我们还需要改进这。
我们添加下面的代码:
(thresh, gray) = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
这段代码的含义,我不是学习图像处理的,于是查了一点资料,这段代码的含义就是图像阈值化处理,怎么说呢,就是图像中有超过阈值的一律归为255,比如这段代码中的阈值为128,只要图片中有像素超过128的就归为255,这样处理出来的效果为:
到这里我们已经处理了这个图片,但有一个最基本的问题就是,数字有可能不在中间,因为MNIST数据集所有的数字都在中间。
首先,我们想将调整20x20的数字位置。
由于对图像处理不太熟悉,这里先贴上代码:
import numpy as np
import cv2
import math
from scipy import ndimage
gray = cv2.imread("own_4.png", cv2.IMREAD_GRAYSCALE)
(thresh, gray) = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
gray = cv2.resize(255-gray,(28,28))
while np.sum(gray[0]) == 0:
gray = gray[1:]
while np.sum(gray[:,0]) == 0:
gray = np.delete(gray,0,1)
while np.sum(gray[-1]) == 0:
gray = gray[:-1]
while np.sum(gray[:,-1]) == 0:
gray = np.delete(gray,-1,1)
rows,cols = gray.shape
if rows>cols:
factor = 20.0/rows
rows = 20
cols = int(round(cols*factor))
gray = cv2.resize(gray,(cols,rows))
else:
factor = 20.0/cols
cols = 20
rows = int(round(rows*factor))
gray = cv2.resize(gray,(cols,rows))
colsPadding = (int(math.ceil((28-cols)/2.0)),int(math.floor((28-cols)/2.0)))
rowsPadding = (int(math.ceil((28-rows)/2.0)),int(math.floor((28-rows)/2.0)))
gray = np.lib.pad(gray,(rowsPadding,colsPadding),'constant')
def getBestShift(img):
cy,cx = ndimage.measurements.center_of_mass(img)
rows,cols = img.shape
shiftx = np.round(cols/2.0-cx).astype(int)
shifty = np.round(rows/2.0-cy).astype(int)
return shiftx,shifty
def shift(img,sx,sy):
rows,cols = img.shape
M = np.float32([[1,0,sx],[0,1,sy]])
shifted = cv2.warpAffine(img,M,(cols,rows))
return shifted
shiftx,shifty = getBestShift(gray)
shifted = shift(gray,shiftx,shifty)
gray = shifted
cv2.imwrite("test.png",gray)
这段代码的意思就是,让不规则的图片变为规则的图片如MNIST,比如这样的:
处理之后是这样的:
这样,就完成了图片的处理。
通过对图片的处理化,我们可以将其应用在实际应用当中。
参考资料:http://openmachin.es/blog/tensorflow-mnist
接下来学习多数字识别,如这样的:
将其中的每一个数字都识别出来。^_^