kNN算法的核心思想是把测试数据和训练数据集通过一些方法计算得出的值(距离)进行排序,然后在前K个值中找出出现频数最高的那个分类.
如右图(取自百度百科),有 蓝正方形 和 红三角 两类,通过计算得到 绿圆(测试数据) 和各个 训练数据 的位置距离,
然后查看离 绿圆 最近的K个训练数据中哪个频数最高就把绿圆归为哪类.图中实线(K=3)中红三角最多所以把 绿圆和红三角 归为一类;而当取虚线(K=5)时 蓝正方形 最多,所以绿圆和 归为蓝正方形.
讲白了这个也就是猜,看离的近的哪类多就归为哪类.
从有图也可以看的出K的取值大小对准确率具有一定的影响.
例:手写数字识别
一.
这里我用的数据集是 UCI的手写数数据集:手写数字数据集
二.
#用一维数组读取图片的数据
num = []
for i in range(0,longs):
strl1 = strl[i][-21:-2]
strl2= strl1.split(' ')
#print(strl2)
for k in range(0,10):
if strl2[k] == '1':
num.append(k)
break
#读取对应的是什么数字
snum = []
for i in range(0,longs):
ss = strl[i][0:-22]
stl = ss.split(' ')
n = []
for k in range(0,len(stl)):
n.append(float(stl[k]))
snum.append(n)
三.
KNN算法不适用这一步所以跳过了.
四.
#测试数据数
testsize = 100
#打印信息用,显示剩余测试数
nmm = testsize
#错误数
error = 0
for i in range(0,testsize):
nnm = 800 + i * 2 #测试数据,从第800个开始,隔一个数据再取
o = [0,0,0,0,0,0,0,0,0,0] #记录排序后前K个距离各个分类的数目
max = 0 #查找o中频数最大分类用
oo = 0 #记录频数最大分类
test = [] #测试数据与各训练数据距离
testnum = [] #记录训练数据对应分类
#计算距离,使用欧式距离公式
for k in range(0,800):
nm = 0
for p in range(0,256):
nm += ((snum[nnm][p] - snum[k][p])*(snum[nnm][p] - snum[k][p]))
test.append(nm**0.5)
for l in range(0,800):
testnum.append(num[l])
bubble_sort(test,testnum) #把测试数据与训练数据距离和对应类比排序
#取前K个中频数最大的
for l in range(0,3):
o[testnum[l]]+=1
for l in range(0,10):
if o[l]>max:
max = o[l]
oo = l
nmm -= 1
print('test:',num[nnm],'-----answer:',oo,'-----remaining:',nmm)
if num[nnm] != oo:
error += 1
五.
('test:', 0, '-----answer:', 0, '-----ramaining:', 99)
('test:', 0, '-----answer:', 0, '-----ramaining:', 98)
('test:', 0, '-----answer:', 0, '-----ramaining:', 97)
('test:', 0, '-----answer:', 0, '-----ramaining:', 96)
('test:', 0, '-----answer:', 0, '-----ramaining:', 95)
('test:', 0, '-----answer:', 0, '-----ramaining:', 94)
('test:', 0, '-----answer:', 0, '-----ramaining:', 93)
('test:', 0, '-----answer:', 0, '-----ramaining:', 92)
('test:', 0, '-----answer:', 0, '-----ramaining:', 91)
('test:', 0, '-----answer:', 0, '-----ramaining:', 90)
('test:', 0, '-----answer:', 0, '-----ramaining:', 89)
('test:', 0, '-----answer:', 0, '-----ramaining:', 88)
('test:', 0, '-----answer:', 0, '-----ramaining:', 87)
('test:', 0, '-----answer:', 0, '-----ramaining:', 86)
('test:', 0, '-----answer:', 0, '-----ramaining:', 85)
('test:', 0, '-----answer:', 0, '-----ramaining:', 84)
('test:', 0, '-----answer:', 0, '-----ramaining:', 83)
('test:', 0, '-----answer:', 0, '-----ramaining:', 82)
('test:', 0, '-----answer:', 0, '-----ramaining:', 81)
('test:', 1, '-----answer:', 1, '-----ramaining:', 80)
('test:', 1, '-----answer:', 1, '-----ramaining:', 79)
('test:', 1, '-----answer:', 1, '-----ramaining:', 78)
('test:', 1, '-----answer:', 1, '-----ramaining:', 77)
('test:', 1, '-----answer:', 1, '-----ramaining:', 76)
('test:', 1, '-----answer:', 1, '-----ramaining:', 75)
('test:', 1, '-----answer:', 1, '-----ramaining:', 74)
('test:', 1, '-----answer:', 1, '-----ramaining:', 73)
('test:', 1, '-----answer:', 1, '-----ramaining:', 72)
('test:', 1, '-----answer:', 1, '-----ramaining:', 71)
('test:', 1, '-----answer:', 1, '-----ramaining:', 70)
('test:', 1, '-----answer:', 1, '-----ramaining:', 69)
('test:', 1, '-----answer:', 1, '-----ramaining:', 68)
('test:', 1, '-----answer:', 1, '-----ramaining:', 67)
('test:', 1, '-----answer:', 1, '-----ramaining:', 66)
('test:', 1, '-----answer:', 1, '-----ramaining:', 65)
('test:', 1, '-----answer:', 1, '-----ramaining:', 64)
('test:', 1, '-----answer:', 1, '-----ramaining:', 63)
('test:', 1, '-----answer:', 1, '-----ramaining:', 62)
('test:', 1, '-----answer:', 1, '-----ramaining:', 61)
('test:', 2, '-----answer:', 2, '-----ramaining:', 60)
('test:', 2, '-----answer:', 2, '-----ramaining:', 59)
('test:', 2, '-----answer:', 2, '-----ramaining:', 58)
('test:', 2, '-----answer:', 2, '-----ramaining:', 57)
('test:', 2, '-----answer:', 2, '-----ramaining:', 56)
('test:', 2, '-----answer:', 2, '-----ramaining:', 55)
('test:', 2, '-----answer:', 2, '-----ramaining:', 54)
('test:', 2, '-----answer:', 2, '-----ramaining:', 53)
('test:', 2, '-----answer:', 4, '-----ramaining:', 52)
('test:', 2, '-----answer:', 2, '-----ramaining:', 51)
('test:', 2, '-----answer:', 2, '-----ramaining:', 50)
('test:', 2, '-----answer:', 2, '-----ramaining:', 49)
('test:', 2, '-----answer:', 2, '-----ramaining:', 48)
('test:', 2, '-----answer:', 2, '-----ramaining:', 47)
('test:', 2, '-----answer:', 2, '-----ramaining:', 46)
('test:', 2, '-----answer:', 2, '-----ramaining:', 45)
('test:', 2, '-----answer:', 2, '-----ramaining:', 44)
('test:', 2, '-----answer:', 2, '-----ramaining:', 43)
('test:', 2, '-----answer:', 2, '-----ramaining:', 42)
('test:', 2, '-----answer:', 2, '-----ramaining:', 41)
('test:', 3, '-----answer:', 3, '-----ramaining:', 40)
('test:', 3, '-----answer:', 3, '-----ramaining:', 39)
('test:', 3, '-----answer:', 3, '-----ramaining:', 38)
('test:', 3, '-----answer:', 3, '-----ramaining:', 37)
('test:', 3, '-----answer:', 3, '-----ramaining:', 36)
('test:', 3, '-----answer:', 3, '-----ramaining:', 35)
('test:', 3, '-----answer:', 3, '-----ramaining:', 34)
('test:', 3, '-----answer:', 3, '-----ramaining:', 33)
('test:', 3, '-----answer:', 3, '-----ramaining:', 32)
('test:', 3, '-----answer:', 3, '-----ramaining:', 31)
('test:', 3, '-----answer:', 3, '-----ramaining:', 30)
('test:', 3, '-----answer:', 3, '-----ramaining:', 29)
('test:', 3, '-----answer:', 3, '-----ramaining:', 28)
('test:', 3, '-----answer:', 5, '-----ramaining:', 27)
('test:', 3, '-----answer:', 3, '-----ramaining:', 26)
('test:', 3, '-----answer:', 3, '-----ramaining:', 25)
('test:', 3, '-----answer:', 3, '-----ramaining:', 24)
('test:', 3, '-----answer:', 3, '-----ramaining:', 23)
('test:', 3, '-----answer:', 3, '-----ramaining:', 22)
('test:', 3, '-----answer:', 3, '-----ramaining:', 21)
('test:', 4, '-----answer:', 4, '-----ramaining:', 20)
('test:', 4, '-----answer:', 4, '-----ramaining:', 19)
('test:', 4, '-----answer:', 4, '-----ramaining:', 18)
('test:', 4, '-----answer:', 4, '-----ramaining:', 17)
('test:', 4, '-----answer:', 4, '-----ramaining:', 16)
('test:', 4, '-----answer:', 4, '-----ramaining:', 15)
('test:', 4, '-----answer:', 4, '-----ramaining:', 14)
('test:', 4, '-----answer:', 4, '-----ramaining:', 13)
('test:', 4, '-----answer:', 4, '-----ramaining:', 12)
('test:', 4, '-----answer:', 4, '-----ramaining:', 11)
('test:', 4, '-----answer:', 4, '-----ramaining:', 10)
('test:', 4, '-----answer:', 1, '-----ramaining:', 9)
('test:', 4, '-----answer:', 4, '-----ramaining:', 8)
('test:', 4, '-----answer:', 4, '-----ramaining:', 7)
('test:', 4, '-----answer:', 4, '-----ramaining:', 6)
('test:', 4, '-----answer:', 4, '-----ramaining:', 5)
('test:', 4, '-----answer:', 4, '-----ramaining:', 4)
('test:', 4, '-----answer:', 4, '-----ramaining:', 3)
('test:', 4, '-----answer:', 4, '-----ramaining:', 2)
('test:', 4, '-----answer:', 4, '-----ramaining:', 1)
('test:', 5, '-----answer:', 5, '-----ramaining:', 0)
('Test num:', 100, '-----error num:', 3.0, '-----', 3.0, '%')
可以看到进行了100次测试有3个猜错了,错误率在3%.
最后:
KNN算法:优点:精度高,对异常值不敏感,无数据输入假定.
缺点:计算复杂度高,空间复杂度高,无法给出任何数据基础结构信息.
适用范围:数值型和标称型.
另外代码写的贼搓,变量名啥随便取的,各位随便看看吧.
#coding:utf-8
import operator
file = open('semeion.data','r')
file.seek(0,2)
size = file.tell()
print("文件长:",size)
file.seek(0,0)
str = file.read(size)
strl = str.split('\n')
#print(strl[0])
#print(len(strl))
longs = len(strl) - 1
num = []
for i in range(0,longs):
strl1 = strl[i][-21:-2]
strl2= strl1.split(' ')
#print(strl2)
for k in range(0,10):
if strl2[k] == '1':
num.append(k)
break
snum = []
for i in range(0,longs):
ss = strl[i][0:-22]
stl = ss.split(' ')
n = []
for k in range(0,len(stl)):
n.append(float(stl[k]))
snum.append(n)
#print(snum[1])
#print(len(snum))
#print(len(snum[1]))
nn = []
nnum = []
def bubble_sort(array,num):
for i in range(len(array)-1):
current_status = False
for j in range(len(array)-i-1):
if array[j] > array[j+1]:
array[j], array[j+1] = array[j+1], array[j]
num[j],num[j+1] = num[j+1],num[j]
current_status = True
if not current_status:
break
testsize = 100
nmm = testsize
error = 0
for i in range(0,0+testsize):
nnm = 800 + i * 2
o = [0,0,0,0,0,0,0,0,0,0]
max = 0
oo = 0
test = []
testnum = []
for k in range(0,800):
nm = 0
for p in range(0,256):
nm += ((snum[nnm][p] - snum[k][p])*(snum[nnm][p] - snum[k][p]))
test.append(nm**0.5)
for l in range(0,800):
testnum.append(num[l])
bubble_sort(test,testnum)
for l in range(0,3):
o[testnum[l]]+=1
#print(testnum[0:50])
#print('o:',o)
for l in range(0,10):
if o[l]>max:
max = o[l]
oo = l
nmm -= 1
print('test:',num[nnm],'-----answer:',oo,'-----ramaining:',nmm)
if num[nnm] != oo:
error += 1
error = float(error)
print('Test num:',testsize,'-----error num:',error,'-----',error/testsize*100,'%')