首先建立一个模块KNN.py,写一个生成数据的函数
from numpy import *
import operator
def createDataSet():
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return group, labels
在python shell 中要调用这个函数,首先要到该模块所在路径,使用os模块中的函数:
import os
查看当前路径 os.getcwd()
更改当前路径 os.chdir()
在所在路径下先import KNN,再使用group, labels = KNN.createDataSet()即可生成数据。
:
1.计算当前点和训练集中的每个点的欧氏距离
2.从小到大排序后取训练集中前k个点
3.返回这些点中出现频率最高的
python实现:
def classify0(inX,dataSet,labels,k):
size = dataSet.shape[0]
tmp = tile(inX,(size,1)) - dataSet
tmp **= 2
tmp = tmp.sum(axis = 1)
tmp **= 0.5
indice = tmp.argsort()
count = {}
for i in range(k):
lb = labels[indice[i]]
count[lb] = count.get(lb,0) + 1
sortedCount = sorted(count.iteritems(),key = operator.itemgetter(1), reverse = True)
return sortedCount[0][0]
语法解析:
1.shape返回array大小,shape[0]为第一维大小(训练集数据数量)
2.tile(A,reg):把A按照reg的形式复制,即:reg是一个矩阵,把矩阵中的每个元素用A替代就是最后结果。例子:
>>> a = np.array([0, 1, 2])
>>> np.tile(a, 2)
array([0, 1, 2, 0, 1, 2])
>>> np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],
[0, 1, 2, 0, 1, 2]])
>>> np.tile(a, (2, 1, 2))
array([[[0, 1, 2, 0, 1, 2]],
[[0, 1, 2, 0, 1, 2]]])
>>> b = np.array([[1, 2], [3, 4]])
>>> np.tile(b, 2)
array([[1, 2, 1, 2],
[3, 4, 3, 4]])
3.argsort(),返回排序后的下标array
4.字典dict.get(key,x)查找键为key的value,如果不存在返回x。(用来初始化很方便)
5.sorted只能用于可迭代类型,因此字典必须用dict.iteritems()返回迭代器。sorted返回结果为list。
6.operator.itemgetter(i)返回对象的第i+1个元素,相当于匿名函数。
5,6的详解可以看http://www.cnblogs.com/100thMountain/p/4719503.html
def file2matrix(filename):
fr = open(filename)
classVector = []
data = fr.readlines()
lineNumber = len(data)
resMat = zeros((lineNumber,3))
index = 0
for line in data:
line = line.strip()
lineList = line.split(' ')
resMat[index,:] = lineList[0:3]
classVector.append(int(lineList[-1]))
index += 1
return resMat, classVector
语法解析:
1.readlines从一个文件中逐行读入,并存入一个list中
2.str.strip():去掉首尾的指定字符,若没有则去掉首尾空格
3.str.split():以指定字符为分割符分割字符串,不指定则为空格
4.把字符串型数字赋值给numpy.array时会自动转成数字类型
def plot(datingDataMat):
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:,1],datingDataMat[:,2])
plt.show()
语法解析:
1.figure 创建一张新的图像
2.add_subplot(111) 表示把图像分割为1行1列,当前子图像画在第1块
3.scatter(X,Y) 以X为x坐标,Y为y坐标绘制散点图
benchmark来自UCI,下载地址:
https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit
每个数字由32*32的灰度图表示,即256个0,1值组成的向量,后面紧跟一个大小为10的label向量
def readData(filename):
fr = open(filename)
data = fr.readlines()
fr.close()
n = len(data)
dataMat = zeros((n,256))
dataLabels = []
for i in range(n):
vals = data[i].split()
dataMat[i,:] = vals[0:256]
dataLabels.append(vals[256:].index('1'))
return dataMat, dataLabels
def test():
dataMat, dataLabels = readData(r"E:\MLData\handwritedNumbers\semeion.data")
m = dataMat.shape[0]
testnum = int(m * 0.1)
wrongnum = 0
for i in range(testnum):
ans = classify0(dataMat[i,:], dataMat[testnum:,:], dataLabels[testnum:], 5)
if (ans != dataLabels[i]): wrongnum += 1
print double(wrongnum) / testnum
list.index(n)返回list中第一个n的下标
最终发现预测错误率只有3%。(这么简单的算法竟然效果这么好,感受到了ML的强大)。
最后附上k值为1~20内的错误率:
1 0.0440251572327
2 0.062893081761
3 0.0440251572327
4 0.0503144654088
5 0.0314465408805
6 0.0503144654088
7 0.0440251572327
8 0.0377358490566
9 0.0377358490566
10 0.0503144654088
11 0.0377358490566
12 0.0377358490566
13 0.0440251572327
14 0.0440251572327
15 0.0440251572327
16 0.0377358490566
17 0.0440251572327
18 0.0440251572327
19 0.0377358490566
可以发现并没有明显规律