K近邻分类器(KNN)手写数字(MNIST)识别

KNN(K-Nearest-Neighbor) 是分类算法中比较简单的一个算法。算法思想非常简单:对于一个未知类别的样例,我们在很多已知类别的样本中找出跟它最相近的K个样本,赋予该样例这K个样本中占多数的类别。
K近邻分类器(KNN)手写数字(MNIST)识别_第1张图片
如图中所示,如果我们选取K值为3,则将样本分类为三角形的类别。而如果K为5,则将样本分类为正方形的类别。这里也可以看出K值的选取很关键。

这里呢,我将用KNN做手写体数字的识别。另人惊异的是,用如此简单的算法也可以获得超过94%的识别准确率。
首先我先介绍一下我用的数据集MNIST,我有的是10500条已经标好类别的样本,我用其中500条做测试样例,用剩下10000条做训练集。其中每一个样本784 位0、1 加上一位类别组成,784位0/1 可以组成28*28的二值图。
K近邻分类器(KNN)手写数字(MNIST)识别_第2张图片
K近邻分类器(KNN)手写数字(MNIST)识别_第3张图片

下面是计算两个样例之间距离的公式,这也是最基本的欧式距离。
欧式距离公式

public static double calDistance(int[] a, int[] b) {
        double temp = 0;
        for (int x = 0; x < a.length; x++) {
            temp += (a[x] - b[x]) * (a[x] - b[x]);
        }
        return temp = Math.sqrt(temp);
    }

下面给出分类代码,这里我的程序是读入测试样例,然后逐条计算它与训练样本的距离,找出K个最接近的样本,统计K个中出现最多的类标赋予给测试样例。如果要用这个代码的话需要稍微改一改,用的编程语言都是Java

public static int classify(String filename, int[] a) throws IOException {
        FileReader fr = new FileReader(filename);
        BufferedReader bufr = new BufferedReader(fr);

        double[] d = new double[K];//存放K近邻的距离

        for (int x = 0; x < K; x++) {//先将所有K近邻的距离初始化为最大距离28
            d[x] = 28;
        }
        double temp = 0;
        int lable = 0;
        int[] num = new int[K];//记录对应距离的类标
        String str = null;
        int t = 0;
        while ((str = bufr.readLine()) != null && t++ < 10000) {
            int[] b = str2int(str.substring(0, str.length() - 1).split(","));
            temp = calDistance(a, b);
            lable = Integer.parseInt(str.substring(str.length() - 1));
            for (int x = 0; x < K; x++) {//找到K近邻的样本
                if (temp < d[x]) {
                    d[x] = temp;
                    num[x] = lable;
                    break;
                }
            }
        }
        bufr.close();
        int[] count = new int[10];
        for (int x = 0; x < K; x++) {//统计各数字出现次数
            count[num[x]]++;
        }
        int result = 0;
        for (int x = 1; x < 10; x++) {//找出出现次数最多的
            if (count[x] > count[result])
                result = x;
        }
        return result;
    }

进一步的改进
关于KNN的改进有以下几个方面:

  1. 加权重,这里的原理是距离测试样本最近的训练样本有比较高的权重。一般权重公式可以为距离的倒数。
  2. 换距离公式,可以换成cos距离
  3. 去除不重要的特征减少计算量;采用特殊的数据结构排序训练样本如kd-tree,减少计算距离次数。

数据集与整个项目的源码我都已经上传,点击下载。值得注意的是,在我的项目里面已经实现了汉明距离与cos距离两种不同距离衡量方法。有任何问题欢迎讨教

你可能感兴趣的:(数据挖掘,机器学习,模式识别经典算法)