【好玩的计算机视觉】KNN算法手写数字识别

OCR应用非常广泛,而且有许多方法,今天用KNN算法实现简单的0-9手写数字识别。本程序使用OpenCV 3.0和Python 3。


KNN算法是K近邻分类算法,属于机器学习中的监督学习,需要一定量的带标签的输入样本数据进行“训练”,然后就可以识别。我给“训练”打引号是因为其实KNN没有明显的前期训练过程,它是要给一个样本x分类,就从数据集中在x附近找离它最近的k各数据点,这k个数据点中包含的y类别最多,那么就把x的标签标记为y,这就完成了分类识别的过程。


首先,利用OpenCV自带的手写数字样本集digits.png来进行初始训练:



def initKnn():
    knn = cv2.ml.KNearest_create()
    img = cv2.imread('digits.png')
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    cells = [np.hsplit(row,100) for row in np.vsplit(gray,50)]
    train = np.array(cells).reshape(-1,400).astype(np.float32)
    trainLabel = np.repeat(np.arange(10),500)
    return knn, train, trainLabel
这是总共5000个数据,0-9各500个,我们读入图片后整理数据,这样得到的train和trainLabel依次对应,图像数据和标签。


def updateKnn(knn, train, trainLabel, newData=None, newDataLabel=None):
    if newData != None and newDataLabel != None:
        print(train.shape, newData.shape)
        newData = newData.reshape(-1,400).astype(np.float32)
        train = np.vstack((train,newData))
        trainLabel = np.hstack((trainLabel,newDataLabel))
    knn.train(train,cv2.ml.ROW_SAMPLE,trainLabel)
    return knn, train, trainLabel
updateKnn是增加自己的训练数据后更新Knn的操作。


def findRoi(frame, thresValue):
    rois = []
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    gray2 = cv2.dilate(gray,None,iterations=2)
    gray2 = cv2.erode(gray2,None,iterations=2)
    edges = cv2.absdiff(gray,gray2)
    x = cv2.Sobel(edges,cv2.CV_16S,1,0)
    y = cv2.Sobel(edges,cv2.CV_16S,0,1)
    absX = cv2.convertScaleAbs(x)
    absY = cv2.convertScaleAbs(y)
    dst = cv2.addWeighted(absX,0.5,absY,0.5,0)
    ret, ddst = cv2.threshold(dst,thresValue,255,cv2.THRESH_BINARY)
    im, contours, hierarchy = cv2.findContours(ddst,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        if w > 10 and h > 20:
            rois.append((x,y,w,h))
    return rois, edges
findRoi函数是找到每个数字的位置,用包裹其最小矩形的左上顶点的坐标和该矩形长宽表示(x, y, w, h)。这里还用到了Sobel算子。edges是原始图像形态变换之后的灰度图,可以排除一些背景的影响,比如本子边缘、纸面的格子、手、笔以及影子等等,用edges来获取数字图像效果比Sobel获取的边界效果要好。

def findDigit(knn, roi, thresValue):
    ret, th = cv2.threshold(roi, thresValue, 255, cv2.THRESH_BINARY)
    th = cv2.resize(th,(20,20))
    out = th.reshape(-1,400).astype(np.float32)
    ret, result, neighbours, dist = knn.findNearest(out, k=5)
    return int(result[0][0]), th
findDigit函数是用KNN来分类,并将结果返回。th是用来手动输入训练数据时显示的图片。20x20pixel的尺寸是OpenCV自带digits.png中图像尺寸,因为我是在其基础上更新数据,所以沿用这个尺寸。


def concatenate(images):
    n = len(images)
    output = np.zeros(20*20*n).reshape(-1,20)
    for i in range(n):
        output[20*i:20*(i+1),:] = images[i]
    return output
concatenate函数是拼接数字图像并显示的,用来输入训练数据。


while True:
    ret, frame = cap.read()
    frame = frame[:,:426]
    rois, edges = findRoi(frame, 50)
    digits = []
    for r in rois:
        x, y, w, h = r
        digit, th = findDigit(knn, edges[y:y+h,x:x+w], 50)
        digits.append(cv2.resize(th,(20,20)))
        cv2.rectangle(frame, (x,y), (x+w,y+h), (153,153,0), 2)
        cv2.putText(frame, str(digit), (x,y), cv2.FONT_HERSHEY_SIMPLEX, 1, (127,0,255), 2)
    newEdges = cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR)
    newFrame = np.hstack((frame,newEdges))
    cv2.imshow('frame', newFrame)
    videoFrame.write(newFrame)
    key = cv2.waitKey(1) & 0xff
    if key == ord(' '):
        break
    elif key == ord('x'):
        Nd = len(digits)
        output = concatenate(digits)
        showDigits = cv2.resize(output,(60,60*Nd))
        cv2.imshow('digits', showDigits)
        cv2.imwrite(str(count)+'.png', showDigits)
        count += 1
        if cv2.waitKey(0) & 0xff == ord('e'):
            pass
        print('input the digits(separate by space):')
        numbers = input().split(' ')
        Nn = len(numbers)
        if Nd != Nn:
            print('update KNN fail!')
            continue
        try:
            for i in range(Nn):
                numbers[i] = int(numbers[i])
        except:
            continue
        knn, train, trainLabel = updateKnn(knn, train, trainLabel, output, numbers)
        print('update KNN, Done!')
这是主函数循环部分,按“x”键会暂停屏幕并显示获取的数字图像,按“e”键会提示输入看到的数字,在终端输入数字用空格隔开,按回车如果显示“update KNN, Done!”则完成一次更新。下面是我用20多组0-9数字更新训练后得到的结果:

【好玩的计算机视觉】KNN算法手写数字识别_第1张图片


演示视频:http://www.bilibili.com/video/av4904541/

完整代码:https://github.com/littlethunder/digitsOCR


转载请注明:转自http://blog.csdn.net/littlethunder/article/details/51615237


你可能感兴趣的:(python,opencv,训练,knn,计算机视觉)