K-近邻算法是最简单的分类算法之一,通过测量不同特征之间的距离进行分类,如果一个样本在特征空间中k个邻近的样本中大多数属于某一类别,则该样本也属于这个类别。kNN中一般用欧式距离作为各个对象之间的非相似性指标:d(x,y):=√∑ni=1(xi−yi)2,也可以使用马氏距离测量距离方法。当训练集、最近邻值K、距离度量、决策规则等确定下来时,算法实际上是把特征空间划分成一个个空间,训练集中每个样本占据一个空间。K-邻近算法一般流程如下。
计算测试数据和训练数据之间的距离
按照距离的递增关系进行排序
选取距离最小的k个点
确定前k个点所在类别的出现频率
返回前k个点中出现频率最高的类别作为测试数据和预测数据
k的选择要在方差和偏差之间取得平衡,若k取值很小,容易因为噪声点和干扰点出现错误,此时方差较大;若K取值很大,预测和真实相差太远,偏差过大。通常利用交叉验证评估一系列不同的k值,选取最好的k作为训练参数。
简单好用,容易理解,精度高,理论成熟,既可以用来做分类也可以用来做回归;
可用于数值型数据和离散型数据;
训练时间复杂度为O(n);无数据输入假定;
对异常值不敏感。
计算复杂性高;空间复杂性高;
样本不平衡问题(即有些类别的样本数量很多,而其它样本的数量很少);
一般数值很大的时候不用这个,计算量太大。但是单个样本又不能太少,否则容易发生误分。
最大的缺点是无法给出数据的内在含义。
补充一点:由于它属于懒惰学习,因此需要大量的空间来存储训练实例,在预测时它还需要与已知所有实例进行比较,增大了计算量。
海伦约会网站上将人分为三类
不喜欢的人=1
魅力一般的人=2
极具魅力的人=3
并且在网站上收集了一些数据,记录了一个人三个指标
每年飞行里程数
玩电子游戏所占百分比
每周吃冰淇淋的公升数
数据样本到下列网页中,找到随书下载源代码ch02中
https://www.ituring.com.cn/book/1021
(1)收集数据:提供文本文件。
(2)准备数据: 使用python解析文本文件。
(3)分析数据:使用matplotlib画二维扩散图。
(4)训练算法:此步驟不适用于k-近邻算法。
(5)测试算法:使用海伦提供的部分数据作为测试样本。
测试样本和非测试样本的区别在于:测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
(6)使用算法:产生简单的命令行程序,然后海伦可以输入一些特征数据以判断对方是否 为自己喜欢的类型。
# -*- coding:utf-8 -*-
# __author__ = "LQ"
import operator
import numpy as np
from os import listdir
import matplotlib
import matplotlib.pylab as plt
# 欧式距离计算函数
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
#tile函数的作用是让某个数组(其实不局限于数组,但我们这里只讨论数组),
# 以某种方式重复,构造出新的数组,所以返回值也是个数组。
#下面是让预测值重复出和训练数组一样行数的数组,然后对应相减
diffMat = np.tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2
#按行就和sum(axis=1),axis=0时按列求和
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
#从小到大排序
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
#根据key从数组中获取值,获取不到初始值0,然后累加1,实现统计作用
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
#sorted(iterable, key=None, reverse=False)
#参数说明:
#iterable - - 可迭代对象。
#key - - 主要是用来进行比较的元素,只有一个参数,具体的函数的参数就是取自于可迭代对象中,
#指定可迭代对象中的一个元素来进行排序。
#reverse - - 排序规则,reverse = True 降序 , reverse = False 升序(默认)。
#operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为一些序号(即需要获取的数据在对象中的序号)
#operator.itemgetter(1)操作元素的value作为排序,降序
#items()以列表形式返回可遍历的(键, 值)元组数组
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
#数据归一化处理
def autoNormal(dataSet):
minVals = dataSet.min(0) # 参数0 按列查找最小值 这样可以找出三个特征值的最小值了
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = np.zeros(np.shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - np.tile(minVals, (m, 1))
# tile(a,reps): a :要复制的值 reps:复制的次数 本次案例中 (m,1) m:复制m行 1:复制1列
normDataSet = normDataSet / np.tile(ranges, (m, 1)) # element wise divide
return normDataSet, ranges, minVals
#读取文本数据
def file2matrix(filename):
# python 自带的open函数 打开我们要解析的数据集
fr = open(filename)
numberOfLines = len(fr.readlines()) # 得到文件行数
# 该数据集有 三个特征值 一个目标向量
returnMat = np.zeros((numberOfLines, 3)) # 创建一个和文件同行数 3列的矩阵(元素为0)
classLabelVector = []
index = 0
fr = open(filename) # 这里要注意 我们再一次打开了文件
# 原因是文件再打开后被使用了一次后自动关闭了 导致我们后面继续读取文件时 读取不到信息 就无法正常操作了
for line in fr.readlines(): # 解析文件数据到列表
line = line.strip() # strip():移除字符串头尾指定的字符序列 无参数为空格和回车字符
# str = "123abcrunoob321"
# print (str.strip( '12' )) # 字符序列为 12
# 首尾的12 都被移除了 顺序不重要
listFromLine = line.split('\t') # split 切分数据 返回值 列表
print("\t分割后的列表:", listFromLine) # 分割后的列表: ['40920', '8.326976', '0.953952', 'largeDoses']
returnMat[index, :] = listFromLine[0:3] # 提取前三个特征值赋值给矩阵
print("第", index, "行得到的列表", returnMat[index, :]) # 第 0 行得到的列表 [4.092000e+04 8.326976e+00 9.539520e-01]
classLabelVector.append(listFromLine[-1]) # 提取位于尾末的目标向量(标签信息)
index += 1
return returnMat, classLabelVector
# 测试分类器
def datingClassTest(filename):
hoRatio = 0.1 # 随机取数据集的10% 海伦提供的数据集本身是无序的 所以这里直接取10%
datingDataMat,datingLabels = file2matrix(filename) # 加载原始数据
# 查看散点图
printData(datingDataMat)
normMat, ranges, minVals = autoNormal(datingDataMat) # 加载归一化数据
m = normMat.shape[0]
numTestVecs = int(m*hoRatio) # 取得10%数据
errorCount = 0.0 # 错误预测数初始化
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print("the classifier came back with: %s, the real answer is: %s" % (classifierResult, datingLabels[i]))
# the classifier came back with: largeDoses, the real answer is: largeDoses
# the classifier came back with: smallDoses, the real answer is: smallDoses
# the classifier came back with: didntLike, the real answer is: didntLike
if (classifierResult != datingLabels[i]): errorCount += 1.0
print("the total error rate is: %d" % ((errorCount/float(numTestVecs))*100), "%")
print(errorCount)
# the total error rate is: 5 %
# 5.0
# 2.分析数据 散点图函数封装
# 根据画出的散点图 我们可以自己直观看到 喜欢程度的分布
def printData(data):
# 我们用散点图 直观显示 特征值的数据分布 三个特征 我们可以画三个子图 这里我们就用一个作案例了
# 用matplotlib创建大图fig 设置字体 黑体
plt.rcParams['font.sans-serif'] = ['Simhei']
fig = plt.figure()
# 定义ax为大图中的子图 111:一行一列第一个
ax = fig.add_subplot(111)
# 画ax子图 坐标轴为特征值列表 索引为1和2的特征
plt.scatter(data[:, 1], data[:, 2])
plt.xlabel("玩游戏所占时间")
plt.ylabel("每周消费冰激凌数")
plt.show()
if __name__ == '__main__':
datingClassTest("datingTestSet.txt")