Recognizing the MNIST dataset with a plain kNN algorithm (with code)

Originally I wanted to build a kd-tree to play with kNN, but a dataset like MNIST is hard to split by dimension. Printing the data out, it looks like any pixel with a grayscale value above 3 is part of a handwritten stroke, so taking the median along a dimension really doesn't work. Given that, I'll first classify with the plain kNN method, see what accuracy and running time look like, and then figure out how to rework MNIST for a kd-tree. The basic idea of kNN is introduced in my earlier post "k-means、GMM聚类、KNN原理概述"; a fairly complete introduction is at http://www.hankcs.com/ml/k-nearest-neighbor-method.html
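To see why a median split does not help, here is a minimal sketch, assuming train_x is the raw (non-binarized) 60000 x 784 matrix returned by the extract_images function in the code below; it counts blank pixels, "ink" pixels (value above 3), and how many of the 784 pixel positions have a median of 0 across the whole training set:

import numpy as np

def inspect_pixels(train_x):
    # train_x: raw grayscale MNIST images, shape (60000, 784), values 0-255.
    total = train_x.size
    print('zero pixels: %.1f%%' % (100.0 * np.sum(train_x == 0) / total))
    print('pixels > 3:  %.1f%%' % (100.0 * np.sum(train_x > 3) / total))
    # If the median of a pixel position is 0, a kd-tree split at the median
    # would put more than half of the images on the "background" side.
    medians = np.median(train_x, axis=0)
    print('pixel positions with median 0: %d / %d'
          % (np.sum(medians == 0), train_x.shape[1]))

# e.g. inspect_pixels(extract_images('data/mnist/train_images', False, True))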
Below is the code for recognizing MNIST with kNN. It covers downloading and extracting MNIST, running the kNN test, and timing how long it takes to classify 1000 test images.

#coding:utf-8

import numpy as np
import os
import gzip
from six.moves import urllib
import operator
from datetime import datetime

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

# Download the MNIST dataset; modeled on the download code in TensorFlow's base.py.
def maybe_download(filename, path, source_url):
    if not os.path.exists(path):
        os.makedirs(path)
    filepath = os.path.join(path, filename)
    if not os.path.exists(filepath):
        urllib.request.urlretrieve(source_url, filepath)
    return filepath

# Read 32 bits at a time in big-endian order, used for the magic number, image count, and image dimensions.
# Modeled on TensorFlow's mnist.py.
def _read32(bytestream):
    dt = np.dtype(np.uint32).newbyteorder('>')
    return np.frombuffer(bytestream.read(4), dtype=dt)[0]

# Extract the images. Optionally binarize the grayscale values, and return the data
# either as a matrix (num_images x 784) or as a tensor (num_images x 28 x 28).
# Modeled on TensorFlow's mnist.py.
def extract_images(input_file, is_value_binary, is_matrix):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, input_file))
        num_images = _read32(zipf)
        rows = _read32(zipf)
        cols = _read32(zipf)
        print(magic, num_images, rows, cols)
        buf = zipf.read(rows * cols * num_images)
        data = np.frombuffer(buf, dtype=np.uint8)
        if is_matrix:
            data = data.reshape(num_images, rows*cols)
        else:
            data = data.reshape(num_images, rows, cols)
        if is_value_binary:
            return np.minimum(data, 1)
        else:
            return data


# Extract the labels.
# Modeled on TensorFlow's mnist.py.
def extract_labels(input_file):
    with gzip.open(input_file, 'rb') as zipf:
        magic = _read32(zipf)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, input_file))
        num_items = _read32(zipf)
        buf = zipf.read(num_items)
        labels = np.frombuffer(buf, dtype=np.uint8)
        return labels

# Plain kNN classification: compute the distance from newInput to every training sample
# at once, take the k images with the smallest distances, and return the label that
# occurs most often among those k images.
# Adapted from http://blog.csdn.net/zouxy09/article/details/16955347
def kNNClassify(newInput, dataSet, labels, k):
    numSamples = dataSet.shape[0] # shape[0] is the number of training samples
    init_shape = newInput.shape[0]
    newInput = newInput.reshape(1, init_shape)
    # np.tile(A, (numSamples, 1)) stacks numSamples copies of A row-wise,
    # so newInput can be compared with every training sample at once.
    diff = np.tile(newInput, (numSamples, 1)) - dataSet  # element-wise difference
    squaredDiff = diff ** 2
    squaredDist = np.sum(squaredDiff, axis=1)  # one squared Euclidean distance per training sample
    distance = squaredDist ** 0.5
    sortedDistIndices = np.argsort(distance)

    classCount = {}  # maps label -> number of votes among the k nearest neighbours
    for i in range(k):
        ## step 3: choose the min k distance
        voteLabel = labels[sortedDistIndices[i]]

        ## step 4: count the times labels occur
        # when the key voteLabel is not in dictionary classCount, get()
        # will return 0
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1

    ## step 5: return the label with the most votes
    maxCount = 0
    maxIndex = 0
    for key, value in classCount.items():
        if value > maxCount:
            maxCount = value
            maxIndex = key

    return maxIndex


# Fetch the four MNIST files (each gzip archive is saved without its .gz suffix;
# gzip.open above does not mind).
maybe_download('train_images', 'data/mnist', SOURCE_URL+TRAIN_IMAGES)
maybe_download('train_labels', 'data/mnist', SOURCE_URL+TRAIN_LABELS)
maybe_download('test_images', 'data/mnist', SOURCE_URL+TEST_IMAGES)
maybe_download('test_labels', 'data/mnist', SOURCE_URL+TEST_LABELS)



# Main routine: load the data, then test on the handwritten digits.
# Adapted from http://blog.csdn.net/zouxy09/article/details/16955347
def testHandWritingClass():
    ## step 1: load data
    print "step 1: load data..."
    train_x = extract_images('data/mnist/train_images', True, True)
    train_y = extract_labels('data/mnist/train_labels')
    test_x = extract_images('data/mnist/test_images', True, True)
    test_y = extract_labels('data/mnist/test_labels')

    ## step 2: training...
    print "step 2: training..."
    pass

    ## step 3: testing
    print "step 3: testing..."
    a = datetime.now()
    numTestSamples = test_x.shape[0]
    matchCount = 0
    test_num = numTestSamples // 10  # test on the first 1000 of the 10000 test images
    for i in range(test_num):

        predict = kNNClassify(test_x[i], train_x, train_y, 3)
        if predict == test_y[i]:
            matchCount += 1
        if i % 100 == 0:
            print "完成%d张图片"%(i)
    accuracy = float(matchCount) / test_num
    b = datetime.now()
    print "一共运行了%d秒"%((b-a).seconds)

    ## step 4: show the result
    print "step 4: show the result..."
    print 'The classify accuracy is: %.2f%%' % (accuracy * 100)

if __name__ == '__main__':
    testHandWritingClass()

The output of a run:

step 1: load data...
2051 60000 28 28
2051 10000 28 28
step 2: training...
step 3: testing...
Processed 0 images
Processed 100 images
Processed 200 images
Processed 300 images
Processed 400 images
Processed 500 images
Processed 600 images
Processed 700 images
Processed 800 images
Processed 900 images
Total running time: 234 seconds
step 4: show the result...
The classify accuracy is: 96.20%

Classifying 1000 images took 234 seconds, so the time cost is much higher than a simple CNN's. The accuracy of 96.2% only beats softmax regression, which reaches about 92%; a multilayer perceptron gets about 98%, trains quickly, and is even faster at test time.
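If the runtime is the main complaint, one straightforward speed-up (just a sketch, not part of the code above; it assumes train_x, train_y and test_x are the binarized matrices produced by extract_images) is to score a whole batch of test images at once with NumPy broadcasting instead of calling kNNClassify image by image:

import numpy as np

def knn_predict_batch(test_batch, train_x, train_y, k=3):
    # Cast to float32 so the dot product below cannot overflow uint8.
    a = test_batch.astype(np.float32)
    b = train_x.astype(np.float32)
    # Squared Euclidean distances from every test image to every training
    # image in one shot: ||a - b||^2 = ||a||^2 - 2 a.b^T + ||b||^2.
    dist2 = (np.sum(a * a, axis=1, keepdims=True)
             - 2.0 * a.dot(b.T)
             + np.sum(b * b, axis=1))
    # Indices of the k nearest training images for each test image.
    nearest = np.argpartition(dist2, k, axis=1)[:, :k]
    # Majority vote over the k neighbour labels, one row per test image.
    votes = train_y[nearest]
    return np.array([np.bincount(row).argmax() for row in votes])

# e.g. predictions = knn_predict_batch(test_x[:1000], train_x, train_y, 3)

The distance matrix for 1000 test images against 60000 training images is 1000 x 60000 float32 values, roughly 240 MB, so split the test set into smaller batches if memory is tight.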
