Kaggle:Digit Recognizer (kNN手写数字识别)

1.KNN
K-NN全称K-Nearest Neighbor,直译就是K近邻算法。是机器学习最简单的算法之一,算法思想很简单,给定一个输入,在训练数据集中找出距离最近的K个值,然后进行投票统计,高票者胜出。当K=1时,算法又叫做最近邻算法,即找出与输入向量最近的向量的标记作为输出。最近邻在训练集上总是会有最好的结果,因此可能会带来过拟合的问题。

K-NN算法简单,但也有一些弊端,比如需要输入两个超参,一是距离的度量,我们可以使用欧氏距离,也可以使用向量夹角的余弦作为度量,不同的距离设置肯定会造成结果的不同。二是k值得选择,不同的k值必然会导致结果的不同。而这两个参数需要人工设置,一种简单的办法就是,尝试不同的超参,然后在数据集上进行测试,选择一种效果最好的。我尝试了一种AdaBoost加上KNN的方法来进行计算,即先计算出不同k值的运算结果,然后再根据不同k值得正确率分配给每个分类器一定的权值,然后将权值进行加权统计,选出最可能的值。

K-NN算法在手写数字上能够有不错的效果,但还有一个问题需要指出,K-NN是一种懒惰学习算法,即在训练的时候只需要将数据进行存储即可,但是测试的时候需要对整个数据集进行运算。这样的结果显而易见。实时性差。对于这个Kaggle竞赛,42000个训练集,以及28000个测试集,7700K单核运算一次测试集的时间大概在半小时左右。勉强能够实现单个数字的实时识别。

2.Digital Recognizer
数据集在Kaggle网站可以下载,相关python环境可以使用Anaconda进行一键配置。
第一步我们先读入数据集

#read csv
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.neighbors import KNeighborsClassifier

train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')
submission = pd.read_csv('sample_submission.csv')#读入这个的原因是我懒
训练集中,每个元素都含有一个标签label,测试集中,这个标签是需要我们运算出来的。也就是识别出图片是哪一个数字。
第二步我们就将数据集进行切割,将标签和数据分割开来
# save the labels to a Pandas series target
target = train[['label']]
# Drop the label feature
train_set = train.drop("label",axis=1)

第三步是调用了sklearn中的方法,生成kNN的对象,然后传入训练集进行训练

kNN = 1
neigh = KNeighborsClassifier(n_neighbors = kNN)
neigh.fit(train_set, target.values.ravel())
第四步,输入测试数据进行预测
[m1, n1] = submission.shape
for i in range(0, m1):
  predict = neigh.predict(test.loc[i].values.reshape(1, -1))
  submission[m1,1] = predict
submission.to_csv('submission_kNN_'+str(kNN)+'.csv', index=False)  

这里我做了一个测试当k=1时在测试集上的正确率大约有97%,当k在2到11时,正确率反而都要低一些,在96%上下波动。


这个地方证明了一个问题,当k=1时确实可能出现过拟合的情况,但是数据集有大约42000个,还是比较大,确实说明数据集的增加是能够降低过拟合的问题。有兴趣可以写一个训练集只有小部分的情况再进行测试。

然后我考虑用多个k值得到的结果进行合并,将每个k值得分类器看做adaBoost中的一个小分类器,然后将多个分类器组合成一个大的分类器。但是结果正确率还是只有百分之96%,这个地方我有一个假设,就是当Adaboost方法中的子分类器的效果都比较好的时候,再累加各个分类器估计效果并不会优化太多,应该会趋向于所有子分类器的平均期望。这个应该可以从概率论上证明,这是我的一个猜想,有时间再思考一下这个问题。


总结一下这一次经历。用k-NN做手写识别确实能做,效果也还不错,但是速度会比较慢,比较依赖数据,也让我真切体会到了数据的宝贵,用了比较好的数据集,即使用一些比较简单的算法也能轻松的发现数据中的规律。考虑以后把神经网络看完,用神经网络或者学习中学到的别的方法,再做一次这个项目。


你可能感兴趣的:(机器学习)