语言:python
平台:pycharm
库: cv2
numpy
keras(这个需要先安装fensorflow库)
手写数字识别,是很多深度学习教程里的入门第一例,但是这些教程往往只告诉了你怎么去构造神经网络,训练模型,最后得出一个准确率,但是很少有教程告诉如何运用这个你训练好模型去识别你自己手写的数字。所以本文会介绍如何对你拍下来的手写数字进行预处理,然后用训练好的模型去识别这个手写数字。
本文运用的方法是svm支持向量机的方式,是一种机器学习,而非深度学习,这种方法的优点是训练模型时的计算量小,60000张的数据集用cpu跑也可以2分钟左右结束。后期博主也将出一篇深度学习方法的手写数字识别。
不管用哪种方法,对于手写数字的图片的预处理都是一样的,所以本文着重介绍图片的预处理。
1.训练(我也不是很能讲解清楚,网上有教程)
2.图像预处理(我主要介绍这部分)
3.测试
这部分可以参考https://keras-lx.blog.csdn.net/article/details/111693750
import cv2
import numpy as np
from keras.datasets import mnist
from keras import utils
if __name__ == '__main__':
# 直接使用Keras载入的训练数据(60000, 28, 28) (60000,)
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 变换数据的形状并归一化
train_images = train_images.reshape(train_images.shape[0], -1) # (60000, 784)
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape(test_images.shape[0], -1)
test_images = test_images.astype('float32') / 255
# 将标签数据转为int32 并且形状为(60000,1)
train_labels = train_labels.astype(np.int32)
test_labels = test_labels.astype(np.int32)
train_labels = train_labels.reshape(-1, 1)
test_labels = test_labels.reshape(-1, 1)
# 创建svm模型
svm = cv2.ml.SVM_create()
# 设置类型为SVM_C_SVC代表分类
svm.setType(cv2.ml.SVM_C_SVC)
# 设置核函数
svm.setKernel(cv2.ml.SVM_POLY)
# 设置其它属性
svm.setGamma(3)
svm.setDegree(3)
# 设置迭代终止条件
svm.setTermCriteria((cv2.TermCriteria_MAX_ITER, 300, 1e-3))
# 训练
svm.train(train_images, cv2.ml.ROW_SAMPLE, train_labels)
svm.save('mnist_svm.xml')
# 在测试数据上计算准确率
# 进行模型准确率的测试 结果是一个元组 第一个值为数据1的结果
test_pre = svm.predict(test_images)
test_ret = test_pre[1]
# 计算准确率
test_ret = test_ret.reshape(-1, )
test_labels = test_labels.reshape(-1, )
test_sum = (test_ret == test_labels)
acc = test_sum.mean()
print(acc)
我们拍了手写数字的照片后,需要对其进行预处理,以满足测试图片的格式要求。
因为训练集的图片都是28*28的二值图片,且数字处于图片正中央,且数字尽量充满图片,所以我们预处理要做的就是把自己拍下来的图片处理成类似的样子。
img0=cv.imread('shuzi2/3.jpg',0)
ret,img1=cv.threshold(img0, 100, 255, cv.THRESH_BINARY_INV) #阈值的选取和亮度有关
我这里是在白纸上写的黑字,所以在二值化的时候做了取反,就是把背景变成黑色,黑字变成白色(因为数据集的图片就是这样)
如果你是在黑板上写白字,那就用
ret,img1=cv.threshold(img0, 100, 255, cv.THRESH_BINARY)
因为背景不一定干净,所以二值化后可能存在噪点,去除噪点后再做一次膨胀,使数字更加清晰。
contours, hierarchy = cv.findContours(img2, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)
for i in range(len(contours)):
area = cv.contourArea(contours[i])
if area <200: # '设定连通域最小阈值,小于该值被清理'
cv.drawContours(img2, [contours[i]], 0, 0, -1)
kernel2=np.ones((15,15),np.uint8) #膨胀
img3=cv.dilate(img2,kernel2)
因为我们拍的数字不一定就在正中间,且充满整个图片,所以需要把这个数字框出来。
代码如下:
contours, hierarchy = cv.findContours(img3, cv.RETR_LIST, cv.CHAIN_APPROX_SIMPLE)
x,y,w,h=cv.boundingRect(contours[0])
a=100
brcnt=np.array([[[x-a,y-a]],[[x+w+a,y-a]],[[x+w+a,y+h+a]],[[x-a,y+h+a]]])
cv.namedWindow('result',0)
cv.drawContours(img3,[brcnt],-1,(255,255,255),2)
cv.imshow('result',img3)
cv.waitKey(0)
img4=img3[y-a:y+h+a,x-a:x+w+a] #img4就是提取roi后的图片
cv.namedWindow('2',0)
cv.resizeWindow('2',600,600)
cv.imshow('2',img4)
cv.waitKey(0)
前面讲过为什么。
代码如下:
img5=cv.resize(img4,(28,28))
import cv2 as cv
import numpy as np
import glob
import os
from skimage import measure
img0=cv.imread('shuzi2/3.jpg',0)
cv.imshow('2',img0)
cv.waitKey(0)
ret,img1=cv.threshold(img0, 100, 255, cv.THRESH_BINARY_INV)
# print(type(img1))
# print(img1.shape)
# print(img1.size)
cv.imshow('2',img1)
cv.waitKey(0)
kernel1 = np.ones((3,3), np.uint8) #做一次膨胀
img2=cv.dilate(img1,kernel1)
# cv.imshow("2",img2)
# cv.waitKey(0)
'剔除小连通域'
contours, hierarchy = cv.findContours(img2, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)
# print(len(contours),hierarchy)
for i in range(len(contours)):
area = cv.contourArea(contours[i])
if area <200: # '设定连通域最小阈值,小于该值被清理'
cv.drawContours(img2, [contours[i]], 0, 0, -1)
# cv.namedWindow('2',0)
# cv.resizeWindow('2',600,600)
# cv.imshow('2',img2)
# cv.waitKey(0)
kernel2=np.ones((15,15),np.uint8)
img3=cv.dilate(img2,kernel2)
# cv.namedWindow('2',0)
# cv.resizeWindow('2',600,600)
# cv.imshow('2',img3)
# cv.waitKey(0)
'roi提取'
contours, hierarchy = cv.findContours(img3, cv.RETR_LIST, cv.CHAIN_APPROX_SIMPLE)
x,y,w,h=cv.boundingRect(contours[0])
a=100
brcnt=np.array([[[x-a,y-a]],[[x+w+a,y-a]],[[x+w+a,y+h+a]],[[x-a,y+h+a]]])
cv.namedWindow('result',0)
cv.drawContours(img3,[brcnt],-1,(255,255,255),2)
cv.imshow('result',img3)
cv.waitKey(0)
img4=img3[y-a:y+h+a,x-a:x+w+a] #img4就是提取roi后的图片
cv.imshow('2',img4)
cv.waitKey(0)
img5=cv.resize(img4,(28,28))
图片预处理好了后,便可以引用训练好的模型对其进行识别。
当然,也可以直接把这个程序放在第二个程序后面。
import cv2
import numpy as np
if __name__=='__main__':
#读取图片
img=cv2.imread('shuzi4.png',0) #这里就是读取预处理好的图片了,当然你也可以把这个程序直接放在第二个程序后面,就不需要这一步了
img_sw=img.copy()
#将数据类型由uint8转为float32
img=img.astype(np.float32)
#图片形状由(28,28)转为(784,)
img=img.reshape(-1,)
#增加一个维度变为(1,784)
img=img.reshape(1,-1)
#图片数据归一化
img=img/255
#载入svm模型
svm=cv2.ml.SVM_load('mnist_svm.xml')
#进行预测
img_pre=svm.predict(img)
print(img_pre[1])
cv2.imshow('test',img_sw)
cv2.waitKey(0)