最近在学习一些python的技术知识和机器学习的相关知识,本章是十大算法的起始篇,K-近邻算法,让我们一起走进机器学习的世界把。
笔者本身是计算机相关专业,机器学习算法在数据挖掘的学习过程中已经接触过了,但是之前的学习是理论知识方面的学习,说实在之前学习也不知道学了有什么用处。等到现在用到了才发现之前也没有好好认真的学习,现在就从头学习把。
先用书本上的栗子来引出为什么学习这个算法。
众所周知,电影可以按照题材来分类,然而题材本身是如何定义的呢?我们根据什么来判断这部电影是属于哪个题材呢?
比如说动作片与爱情片,爱情片中很多也存在动作的镜头,而动作片中也有很多接吻的镜头等。如何判断电影是属于爱情片还是动作片呢?
相对来说肯定是动作片里面的动作镜头更加的多,然后爱情片里面的爱情戏份更加的多,我们就通过判断电影里面相对占的次数比较多来判断属于哪种类型的电影。
简单的说,K-近邻算法就是采用测量不同特征值之间的距离进行分类。 通常K是不大于20的整数。
详细的描述就是:在已有一个样本数据集的情况下,我们已知每个样本对应的所属分类,然后输入新的数据样本,将新的数据样本特征值与已有的数据样本的特征值进行比较,然后得到新的数据样本属于数据集的哪个分类。
我们回归到电影分类的栗子,使用 K-近邻算法分类爱情片和动作片:
下图显示了研究的6部电影的打斗和接吻镜头:
图中的?号所在的位置是该位置电影出现的镜头数图形化展示,具体数字参见下表:
电影名称 | 打斗镜头 | 接吻镜头 | 电影类型 |
---|---|---|---|
California Man | 3 | 104 | 爱情片 |
He’s Not Really into Dudes | 2 | 100 | 爱情片 |
Beautiful Woman | 1 | 81 | 爱情片 |
Kevin Longblade | 101 | 10 | 爱情片 |
Robo Slayer 3000 | 99 | 5 | 爱情片 |
Amped II | 98 | 2 | 爱情片 |
? | 18 | 90 | 爱情片 |
即使不知道电影属于哪种类型,我们也可以通过某种方法计算出来,首先计算未知电影与样本中其他电影的距离,得到如下表的数据,暂时不关心如何计算得到。
电影名称 | 与未知电影的距离 |
---|---|
California Man | 20.5 |
He’s Not Really into Dudes | 18.7 |
Beautiful Woman | 19.2 |
Kevin Longblade | 115.3 |
Robo Slayer 3000 | 117.4 |
Amped II | 118.9 |
现在我们得到了样本集中所有的电影与未知电影的距离,按照距离递增排序,可以找到K个距离最近的电影。假定K=3,则得到三个电影分别是California Man,He’s Not Really into Dudes ,Beautiful Woman。这三部电影都是爱情片,所以我们的结论为未知电影属于爱情片。
接下来我们将使用python来开发K-近邻算法:
K-近邻算法的一般流程:
1.收集数据:可以使用任何方法。
2.准备数据:距离计算所需要的数值,最好是结构化的数据格式。
3.分析数据:可以使用任何方法。
4.训练算法:此步骤不适用于K-近邻算法。
5.测试算法:计算错误率。
6.使用算法:首先需要输入样本数据和结构化的输出结果,然后运行K-近邻算法判断输入的新样本数据属于哪个分类,最后应用对计算出的分类执行后续的处理。
####python实现开始
######1.使用python导入数据
我们已经准备好了一个名为kNN.py的文件,可以在kNN.py网址中选择源代码进行下载,也可以跟随本教程一步一步自己来理解并实现这个算法。
#######2.实施kNN分类算法
我们先来看一下使用伪代码表述的算法:
对未知类别属性的数据集中的每个点依次执行以下操作:
1.计算已知类别数据集中的每个点与当前点之间的距离。(常用欧几里得距离或曼哈顿距离)
2.按照距离递增次序排序。
3.选取与当前点次序最小的K个点。
4.确定前K个点所在类别的出现概率。
5.返回前K个点出现频率最高的类别作为当前点的的预测分类。
import numpy as np
import operator
# k-近邻算法
# 距离计算采用欧几里得距离
# inX 用于分类的输入向量
# dataSet 输入的训练样本集
# labels 标签向量
# K 选择的邻近样本数量
def classify(inX, dataSet, labels, K):
#得到数据集的行数 shape方法用来得到矩阵或数组的维数 shape[0]表示行数 shape[1]表示列数
dataSetSize = dataSet.shape[0]
#tile:numpy中的函数。tile将原来的一个数组复制成一个新的数组
# a = [1,2]
# tile(a,2) [1 2 1 2]
# tile(a,(1,2)) [[1 2 1 2]] 扩展列数
# tile(a,(2,1)) [[1 2][1 2]] 扩展行数
#diffMat 就是将输入向量扩展成与dataSet相同行数的新向量 - dataSet 得到的结果其实就是x1-x0
diffMat = np.tile(inX, (dataSetSize,1)) - dataSet
# 计算得到 (x1 - x0)^2
sqDiffMat = diffMat**2
# 得到平方和
# 当axis为0时,是压缩行,即将每一列的元素相加,将矩阵压缩为一行
# 当axis为1时,是压缩列,即将每一行的元素相加,将矩阵压缩为一列
sqDistances = sqDiffMat.sum(axis=1)
# 计算距离
distance = sqDistances**0.5
# 对得到的距离进行索引排序
# a [1,3,2]
# a.argsort() [0,2,1]
sortedDistIndicies = distance.argsort()
classCount = {
}
for i in range(K):
#voteIlabel的值等于给定labels[距离的索引排序]
voteIlabel = labels[sortedDistIndicies[i]]
#依次查询classCount中是否有该key,有则将取出value再+1,没有则返回添加该key并置value为0,再+1
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
# sorted函数进行临时排序 排序为从大到小进行排序 python3.5之后都是.items
sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1), reverse=True)
# 返回结果中最大的值 即是最接近的分类值
return sortedClassCount[0][0]
# 构造训练样本集
def createDataSet():
group = np.array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return group, labels
group = createDataSet()[0]
labels = createDataSet()[1]
result = classify([1,1.2],group,labels,3)
print(result) # A
上述代码亲测在python3.6.3环境下可以正常的运行。但是实际上上述栗子只是现实世界中非常简单的一个栗子,接下来我们将来继续探讨K-近邻算法在现实世界中的应用。
####示例:使用K-近邻算法改进约会网站的配对效果
帮助海伦在约会网站中将他的意向匹配对象划到相对应的分类中。
######1.准备数据:从文本文件中解析数据
海伦已经收集了某些历史数据,存放在data.txt中下载地址——>源代码——>Ch02——>datingTestSet2.txt。其中每个样本数据占据一行,总共1000行,海伦的样本数据主要包含以下3个特征:
我们将创建一个file2matrix函数来将输入的文件名字符串转化为训练样本矩阵和类标签向量。
import numpy as np
import operator
# 将文本记录转化为Numpy解析的数组
def file2matrix(filename):
# open() 函数用于打开一个文件,创建一个 file 对象
fr = open(filename)
# 读取文件的每一行 readlines() 自动将文件内容分析成一个行的列表
arrayOLines = fr.readlines()
# 得到所有数据的条数
numberOfLines = len(arrayOLines)
# 得到用0填充的数组 zeros(shape, dtype=float, order='C') 返回一个给定形状和类型的用0填充的数组
# 至于后面为什么是3列呢? 因为我们的特征主要是3大特征
# returnMat 为一个1000行 3列的0元素填充的数组
returnMat = np.zeros((numberOfLines,3))
# 控制循环与分类标签
classLabelVector = []
index = 0
# 遍历每行数据
for line in arrayOLines:
# 去除前后空格
line = line.strip()
# 对于每一行,按照制表符切割字符串,得到的结果构成一个数组,数组的每个元素代表一行中的一列
listFormLine = line.split('\t')
# 将得到的数组中的 前三列赋值给构造的数组的每一行
returnMat[index,:] = listFormLine[0:3]
# 将训练样本数据集的分类全部添加进classLabelVector
classLabelVector.append(int(listFormLine[-1]))
index += 1
return returnMat, classLabelVector
print(file2matrix('data.txt'))
######2.使用matplotlib来创建散点图
在我们得到了需要的数据格式之后,我们将使用matplotlib来进行散点图的绘制,观察数据的分布情况。
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
font = FontProperties(fname=r"/System/Library/Fonts/PingFang.ttc", size=14)
# 这个引入的就是上方编写的那段代码
from test2 import file2matrix,datingDataMat,datingLabels
# 新建一个名叫 fig的画图窗口
fig = plt.figure()
# 参数111的意思是:将画布分割成1行1列,图像画在从左到右从上到下的第1块
ax = fig.add_subplot(1,1,1)
# 构造横纵坐标
x = datingDataMat[:,1] # type ndarray
y = datingDataMat[:,2] # type ndarray
# 匹配参数
ax.scatter(
x,
y,
15.0*np.array(datingLabels),
15.0*np.array(datingLabels)
)
plt.title(u'游戏与冰淇淋散点图', FontProperties=font)
plt.xlabel(u'玩视频游戏所耗时间百分比', FontProperties=font)
plt.ylabel(u'每周消费的冰淇淋公升数', FontProperties=font, verticalalignment='baseline',horizontalalignment='center', rotation='vertical')
plt.show()
我们再来看一下另外两个参数的比较
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
font = FontProperties(fname=r"/System/Library/Fonts/PingFang.ttc", size=14)
from test2 import file2matrix, datingDataMat, datingLabels
# 设置图例的显示中文
plt.rcParams['axes.unicode_minus'] = False
# 新建一个名叫 fig的画图窗口
fig = plt.figure()
# 参数111的意思是:将画布分割成1行1列,图像画在从左到右从上到下的第1块
ax = fig.add_subplot(1, 1, 1)
# 一共有3类,所以定义3个空列表准备接受数据
type1_x = []
type1_y = []
type2_x = []
type2_y = []
type3_x = []
type3_y = []
# 1000组数据,i循环1000次
for i in range(len(datingLabels)):
# 根据标签进行数据分类
if datingLabels[i] == 1:
print(datingDataMat[i][0])
type1_x.append(datingDataMat[i][0]) # 取的是样本数据的第一列特征和第二列特征
type1_y.append(datingDataMat[i][1])
if datingLabels[i] == 2:
type2_x.append(datingDataMat[i][0])
type2_y.append(datingDataMat[i][1])
if datingLabels[i] == 3:
type3_x.append(datingDataMat[i][0])
type3_y.append(datingDataMat[i][1])
ax.scatter(type1_x, type1_y, s=20, c='r', label=u'不喜欢')
ax.scatter(type2_x, type2_y, s=20, c='b', label=u'魅力一般')
ax.scatter(type3_x, type3_y, s=20, c='k', label=u'极具魅力')
plt.title(u'飞行里程数与玩游戏散点图', FontProperties=font)
plt.xlabel(u'每年获取的飞行常客里程数', FontProperties=font)
plt.ylabel(u'玩视频游戏所耗时间百分比', FontProperties=font, verticalalignment='baseline',
horizontalalignment='center', rotation='vertical')
# 设置图例的字体
plt.legend(prop=font)
plt.show()
######3.进行数据归一化
什么是数据归一化呢?
简单来说就是把原有的数据映射到0到1或者-1到1之间。
newValue = ( oldValue - min ) / ( max - min )
我们看下表:
玩视频游戏所耗时间百分比 | 每年获得的飞行常客里程数 | 每周消耗的冰淇淋公升数 | 样本分类
比如我们计算样本3和样本4之间的距离,发现每年获取的飞行常客里程数对于计算结果远远大于其他的两个特征,然而这样的结果是不太公平的,所以我们将对数据进行归一。
import numpy as np
def autoNorm(dataSet):
# 在每一列中选取最小值,而不是选取当前行的最小值
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = np.zeros(np.shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - np.tile(minVals, (m,1))
normDataSet = normDataSet/np.tile(ranges, (m,1))
return normDataSet, ranges, minVals
在进行归一化数据后,我们将进行检测我们的算法错误率。
import numpy as np
from test2 import datingDataMat, datingLabels, autoNorm, classify
def datingClassTest():
hoRatio = 0.1
# 归一化后的原始数据
normMat, ranges, minVals = autoNorm(datingDataMat)
# 原始数据总共1000行
m = normMat.shape[0]
# 使用10%的数据作为测试数据 100条记录
numTestVecs = int(m*hoRatio)
# 设置初始错误计数
errorCount = 0.0
# 0 - 1000
for i in range(numTestVecs):
print(str(i),normMat[i,:])
classifierResult = classify(normMat[i,:],normMat[numTestVecs:m,:], datingLabels[numTestVecs:m], 3)
print("the result sort is:%d,the real answer is:%d" %(classifierResult, datingLabels[i]))
if(classifierResult != datingLabels[i]):errorCount += 1.0
print("the total error rate is:%f"%(errorCount/float(numTestVecs)))
datingClassTest()
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:3,the real answer is:3
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:3,the real answer is:3
# the result sort is:3,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:3
# the result sort is:1,the real answer is:1
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:3
# the result sort is:3,the real answer is:3
# the result sort is:2,the real answer is:2
# the result sort is:1,the real answer is:1
# the result sort is:3,the real answer is:1
# the total error rate is:0.050000
最后我们来模拟一下真实的筛选现场,我们把打分的输入交给用户输入,然后根据输入判断是否是海伦喜欢的类型:
def classfiyPerson():
resultList = ['一点也不喜欢','有一点喜欢','愿意进一步交往']
percentTats = float(input("花费在打游戏的时间比例?"))
ffMiles = float(input("每年获取的飞行常客里程数?"))
iceCream = float(input("每年消费多少公升的冰淇淋?"))
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = np.array([ffMiles, percentTats, iceCream])
classifierResult = classify((inArr - minVals) / ranges,normMat,datingLabels, 3)
print("the result is:",resultList[classifierResult - 1])
classfiyPerson()
####手写识别系统
本节我们继续学习使用K-近邻算法实现一个手写识别算法。简单起见,这里的构造系统只能识别数字0-9,需要识别的数字已经使用图形软件处理成具有相同的色彩和大小,32*32。
K-近邻算法实现手写识别系统:
1.收集数据:提供文本文件。
2.准备数据:编写img2vector,将图像格式转化为分类器使用的向量格式。
3.分析数据:检查数据是否符合要求。
4.训练算法:此步骤不适用于K-近邻算法。
5.测试算法:计算错误率。
6.使用算法:首先需要输入样本数据和结构化的输出结果,然后运行算法判断输入的新样本数据属于哪个数字,最后应用对计算出的分类执行后续的处理。
######准备数据,将图像转化为向量
我们构造一个方法,将输入的图像转化为一个向量
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
######测试算法,使用K-近邻算法识别手写数字
这里用到的数据同样可以在上述地址进行下载。
# 测试识别手写数字
def handwritingClassTest():
# 正确的结果集
hwLabels = []
# 使用listdir方法可以列出所给定目录的文件名
trainingFileList = listdir('trainingDigits')
# 1934
m = len(trainingFileList)
print(m)
# np.zeros((1934,1024)) 1934行1024列数组
trainingMat = np.zeros((m,1024))
for i in range(m):
# 获取完整文件名
fileNameStr = trainingFileList[i]
# 截取文件名
fileStr = fileNameStr.split('.')[0]
# 获取数字实际结果 我们的文件名类似长这样 0_1.txt
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
# 构造转化后的数据
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
# 构造测试数据
testFileList = listdir('testDigits')
# 初始错误计数
errCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
# 获得预测结果
classifierResult = classify(vectorUnderTest,trainingMat,hwLabels,3)
print("the result is:%d ,the real answer is:%d" % (classifierResult, classNumStr))
if(classifierResult != classNumStr):errCount += 1.0
# 打印错误计数和错误率
print("the total error count is:%d" % errCount)
print("the error rate is:%f" % (errCount/float(mTest)))
handwritingClassTest()
到目前为止,K-近邻算法介绍基本结束了,总结一下K-近邻算法。
K-近邻算法是一种最简单有效的算法,但是使用该算法的时候,我们必须要有足够量的训练数据,如果训练数据集比较大,我们就要消耗大量的存储空间,而且由于算法需要对每个数据计算距离值,实际效率可能不是很好。
有没有什么更好的解决办法呢?下章我们将探讨k决策树,一种K-近邻算法的优化版本。后面如果还有示例,会追加到尾部。