K-NN算法在手写数字上能够有不错的效果,但还有一个问题需要指出,K-NN是一种懒惰学习算法,即在训练的时候只需要将数据进行存储即可,但是测试的时候需要对整个数据集进行运算。这样的结果显而易见。实时性差。对于这个Kaggle竞赛,42000个训练集,以及28000个测试集,7700K单核运算一次测试集的时间大概在半小时左右。勉强能够实现单个数字的实时识别。
#read csv
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.neighbors import KNeighborsClassifier
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')
submission = pd.read_csv('sample_submission.csv')#读入这个的原因是我懒
训练集中,每个元素都含有一个标签label,测试集中,这个标签是需要我们运算出来的。也就是识别出图片是哪一个数字。
# save the labels to a Pandas series target
target = train[['label']]
# Drop the label feature
train_set = train.drop("label",axis=1)
第三步是调用了sklearn中的方法,生成kNN的对象,然后传入训练集进行训练
kNN = 1
neigh = KNeighborsClassifier(n_neighbors = kNN)
neigh.fit(train_set, target.values.ravel())
第四步,输入测试数据进行预测
[m1, n1] = submission.shape
for i in range(0, m1):
predict = neigh.predict(test.loc[i].values.reshape(1, -1))
submission[m1,1] = predict
submission.to_csv('submission_kNN_'+str(kNN)+'.csv', index=False)
这里我做了一个测试当k=1时在测试集上的正确率大约有97%,当k在2到11时,正确率反而都要低一些,在96%上下波动。