机器学习 使用python+OpenCV实现knn算法手写数字识别

基本上照搬了http://lib.csdn.net/article/opencv/30167的代码,只是改了一点bug和增加了一点功能

输入就是直接在一个512*512大小的白色画布上画黑线,然后转化为01矩阵,用knn算法找训练数据中最相近的k个,现在应该是可以对所有字符进行训练和识别,只是训练数据中还只有数字而已,想识别更多更精确的话就需要自己多跑代码多写几百次,现在基本上一个数字写10次左右准确率就挺高了,并且每次识别的时候会将此次识别的数字和01矩阵存入训练数据文件夹中,增加以后识别的正确率,识别错了的话需要输入正确答案来扩充训练数据

/*--------------------------------------------------之前忘了说画完按回车了--------------------------*/

这是效果图:机器学习 使用python+OpenCV实现knn算法手写数字识别_第1张图片

机器学习 使用python+OpenCV实现knn算法手写数字识别_第2张图片

这是代码

knn.py

from numpy import *
import operator
import time
from os import listdir
def classify(inputPoint,dataSet,labels,k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inputPoint,(dataSetSize,1))-dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[ sortedDistIndicies[i] ]
        classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]
def img2vector(filename):
    returnVect = []
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect.append(int(lineStr[j]))
    return returnVect
def classnumCut(fileName):
    fileStr = fileName.split('.')[0]
    classNumStr = fileStr.split('_')[0]
    return classNumStr
def trainingDataSet():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')		  
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))						
    for i in range(m):
        fileNameStr = trainingFileList[i]
        hwLabels.append(classnumCut(fileNameStr))
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
        #print type(trainingMat)
    return hwLabels,trainingMat
draw.py
#encoding:utf-8
import cv2
import numpy as np #mouse callback function
from knn import *
ix,iy=-1,-1
ENTER = 10
drawing = False
#创建图像与窗口并将窗口与回调函数绑定
def in_img():
    for i in range(512):
        img[i,:]=255
    cv2.namedWindow('image')
    cv2.setMouseCallback('image',draw_circle)
    while(1):
        cv2.imshow('image',img)
        if cv2.waitKey(20)& 0xFF == ENTER:
            cv2.imwrite( '1.jpg',img)
            break
    cv2.destroyAllWindows()
def draw_circle(event,x,y,flags,param):
    global ix,iy,drawing
    if event==cv2.EVENT_LBUTTONDOWN:
        drawing=True
        ix,iy=x,y
    elif event==cv2.EVENT_MOUSEMOVE:
        if drawing==True:
            cv2.circle(img,(x,y),30,(0,0,0),-1)
    elif event==cv2.EVENT_LBUTTONUP:
        drawing=False
def classnum(fileName):
    fileStr = fileName.split('.')[0]
    classNumStr = fileStr.split('_')[0]
    num = int(fileStr.split('_')[1])
    return classNumStr,num
def read_image():
    img1 = cv2.imread('1.jpg', cv2.IMREAD_GRAYSCALE)
    res=cv2.resize(img1,(32,32),interpolation=cv2.INTER_CUBIC)
    cv2.imshow('2',res)
    pic=[]
    for i in range(32):
        for j in range(32):
            if res[i][j]<=200:
                res[i][j]=1
            else:
                res[i][j]=0
            pic.append(int(res[i][j]))
    hwLabels,trainingMat = trainingDataSet()
    classifierResult = classify(pic, trainingMat, hwLabels, 3)

    a = raw_input('is it '+ str(classifierResult)+'? input y/n.\n')
    c = 0
    if a == 'n' or a == 'N':
        b = raw_input('So what is it?\n')
        trainingFileList = listdir('trainingDigits')          
        m = len(trainingFileList)
        trainingMat = zeros((m,1024))                       
        for i in range(m):
            fileNameStr = trainingFileList[i]
            x,y = classnum(fileNameStr)
            if x == b:
                if y > c:
                    c = y
        c = c+1
        newfile = 'trainingDigits/' + str(b)+'_'+str(c)+('.txt')
        f=open(newfile,'w')
        for i in range(32):
            for j in range(32):
                f.write(str(res[i][j]))
            f.write("\n")
        f.close()
        print "I'll be smarter next time"
    else:
        b = str(classifierResult)
        trainingFileList = listdir('trainingDigits')          
        m = len(trainingFileList)
        trainingMat = zeros((m,1024))                       
        for i in range(m):
            fileNameStr = trainingFileList[i]
            x,y = classnum(fileNameStr)
            if x == b:
                if y > c:
                    c = y
        c = c+1
        newfile = 'trainingDigits/' + str(b)+'_'+str(c)+('.txt')
        f=open(newfile,'w')
        for i in range(32):
            for j in range(32):
                f.write(str(res[i][j]))
            f.write("\n")
        f.close()

def main():
    global img
    img=np.zeros((512,512,3),np.uint8)
    in_img()
    read_image()
if __name__=="__main__":
    main()

这是打包的代码和我自己写的几十个训练数据

https://download.csdn.net/download/qq_40051709/10282410

你可能感兴趣的:(机器学习)