这是我第一次写博客,也是刚学的KNN,写``的目的就是为了增强我对该算法的理解。有问题请大家给我指出,谢谢啦。
一.算法部分
1.客观理解:最简单最初级的分类器是将全部的训练数据所对应的类别都记录下来,当测试对象的属性和某个训练对象的属性完全匹配时,便可以对其进行分类。但是怎么可能所有测试对象都会找到与之完全匹配的训练对象呢,其次就是存在一个测试对象同时与多个训练对象匹配,导致一个训练对象被分到了多个类的问题,基于这些问题呢,就产生了KNN。
KNN是通过测量不同特征值之间的距离进行分类。它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别,其中K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
下面通过一个简单的例子说明一下:如下图,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。
2.(个人理解):其实KNN算法的核心就是找到与测试数据最近的训练数据。
3.算法核心(PYTHON代码部分):all_distance = np.sqrt(np.sum(np.square(np.tile(test_data,(dataSize,1))-train_data),axis=1))
二.带上代码实现手写数字识别
1.需要引用的库:
import numpy as np #需要对数字进行处理
from os import listdir #获取标签也就是训练集中手写数字的值
import operator
import datetime #求得运算时间
2.对手写数字的训练集进行处理:
def img2vector(filename): #将文本中32*32的矩阵转化成1*1024的矩阵,便于对数据进行处理
dataMat = np.zeros(1024,int)
fr = open(filename)
lines = fr.readlines()
for i in range(32):
for j in range(32):
dataMat[32*i+j] = lines[i][j]
return dataMat
def readDataSet(path): #读取所有的数据和其标签
file_titles = listdir(path)
file_lenth = len(file_titles)
dataSet = np.zeros([file_lenth,1024],int)
label = np.zeros([file_lenth])
for i in range(file_lenth):
file_title = file_titles[i]
dataSet[i] = img2vector(path+'/'+file_title)
label[i] = int(file_title.split('_')[0])
return dataSet,label
3.KNN算法:
def KNN(test_data,train_data,train_label,k):
dataSize = train_data.shape[0] #计算数据的总长度
all_distance = np.sqrt(np.sum(np.square(np.tile(test_data,(dataSize,1))-train_data),axis=1)) #用tile函数将test数据拓展成也训练数据相同的长度,然后在计算test与train的欧式距离
distance_sort_index = np.argsort(all_distance) #对得到的距离进行排序,获取其对应的下标
dictionary = {} #创建字典,用于对临近数据的储存
for i in range(k): #获取临近数据
predicted_value = train_label[distance_sort_index[i]]
dictionary[predicted_value]=dictionary.get(predicted_value,0)+1
sored_dictionary = sorted(dictionary.items(),key=operator.itemgetter(1)) #对获取的临近数据进行排序,获得出现次数最多的数据,就是其预测数值
return sored_dictionary[0][0]
4.全部程序:
import numpy as np
from os import listdir
import operator
import datetime
def KNN(test_data,train_data,train_label,k):
dataSize = train_data.shape[0] #计算数据的总长度
all_distance = np.sqrt(np.sum(np.square(np.tile(test_data,(dataSize,1))-train_data),axis=1)) #用tile函数将test数据拓展成也训练数据相同的长度,然后在计算test与train的欧式距离
distance_sort_index = np.argsort(all_distance) #对得到的距离进行排序,获取其对应的下标
dictionary = {} #创建字典,用于对临近数据的储存
for i in range(k): #获取临近数据
predicted_value = train_label[distance_sort_index[i]]
dictionary[predicted_value]=dictionary.get(predicted_value,0)+1
sored_dictionary = sorted(dictionary.items(),key=operator.itemgetter(1)) #对获取的临近数据进行排序,获得出现次数最多的数据,就是其预测数值
return sored_dictionary[0][0]
def img2vector(filename): #将文本中32*32的矩阵转化成1*1024的矩阵,便于对数据进行处理
dataMat = np.zeros(1024,int)
fr = open(filename)
lines = fr.readlines()
for i in range(32):
for j in range(32):
dataMat[32*i+j] = lines[i][j]
return dataMat
def readDataSet(path): #读取所有的数据和其标签
file_titles = listdir(path)
file_lenth = len(file_titles)
dataSet = np.zeros([file_lenth,1024],int)
label = np.zeros([file_lenth])
for i in range(file_lenth):
file_title = file_titles[i]
dataSet[i] = img2vector(path+'/'+file_title)
label[i] = int(file_title.split('_')[0])
return dataSet,label
def PredictNumber(filename):#输入一个32*32的矩阵,判断数值
matrix = img2vector(filename)
train_data, train_label = readDataSet('trainingDigits')
predict_number = KNN(matrix,train_data,train_label,3)
print(predict_number)
return
def main():
t1 = datetime.datetime.now() #计算训练所用的时间
nearest_neighbor_number = 3
train_data , train_label = readDataSet('trainingDigits') #获取训练数据集
test_data , test_label = readDataSet('testDigits') #获得测试数据集
test_titles = listdir('testDigits')
test_lenth = len(test_titles)
error_sum = 0 #计算出现错误的次数
for i in range(test_lenth):
real_number = test_label[i]
predict_number = KNN(test_data[i],train_data,train_label,nearest_neighbor_number)
print ("第",i+1,"组:","预测值:",predict_number,"真实值:",real_number)
if(predict_number != real_number):
error_sum += 1.0
t2 = datetime.datetime.now()
print('总共测试用时',t2-t1)
print("测试出现错误次数:",error_sum)
print ("\n错误率:",error_sum/float(test_lenth)*100,'%')
main()
本文主要参考来源:https://blog.csdn.net/zzz_cming/article/details/78938107
https://www.cnblogs.com/ybjourney/p/4702562.html