机器学习实战系列之学习笔记主要是本人进行学习机器学习的整理。本系列所有代码是用python3编写,并使用IDE Pycharm在Windows平台上编译通过。本系列所涉及的所有代码和资料可在我的github或者码云上下载到,gitbub地址:https://github.com/mcyJacky/MachineLearning,码云地址:https://gitee.com/mcyHome/MachineLearning,如有问题,欢迎指出~。
k近邻法(k-nearest neighbor, K-NN)是一种基本分类和回归方法,在1968年由Cover和Hart提出。它的工作原理是:存在一个样本数据集合,也称做训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后用算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类(多数表决),作为新数据的分类。
例如我们用KNN算法根据电影打斗镜头和接吻镜头来预测分类该影片是爱情片还是动作片,首先我们要知道我们预测的电影存在多少打斗镜头和接吻镜头,如下表1.1问号位置是该电影的预测结果:
电影名称 | 打斗镜头 | 接吻镜头 | 电影类型 |
---|---|---|---|
A电影 | 3 | 104 | 爱情片 |
B电影 | 2 | 100 | 爱情片 |
C电影 | 1 | 81 | 爱情片 |
D电影 | 101 | 10 | 动作片 |
E电影 | 99 | 5 | 动作片 |
F电影 | 18 | 90 | ? |
如表1.1所示,我们即使不知道F电影属于什么类型,我们也可以通过KNN方法计算出来,首先要计算F电影与样本集中其它电影的距离,我们先不关心如何计算距离,具体会在下面描述。如表1.2是F电影与其它电影的距离计算结果。
电影名称 | 与F电影的距离 |
---|---|
A电影 | 20.5 |
B电影 | 18.7 |
C电影 | 19.2 |
D电影 | 115.3 |
E电影 | 117.4 |
如表1.2所示,样本集中所有电影与F电影的距离,按照距离递增排序,可以找到k个距离最近的电影。假设k=3,则最靠前的是A电影、B电影、C电影。KNN算法按照距离最近的三部电影的类型通过多数表决方法决定,而这三部电影都是爱情片,因此我们预测F电影也是爱情片。
特征空间中两个实例点的距离是两个实例点相似程度的反映。KNN模型的特征空间一般是 n n 维实数向量的空间向量Rn R n 。使用的是欧式距离,但也有其它距离。
假设特征向量 xi,xj x i , x j 分别为 xi=(x(1)i,x(1)i,...,x(n)i)T x i = ( x i ( 1 ) , x i ( 1 ) , . . . , x i ( n ) ) T , xj=(x(1)j,x(1)j,...,x(n)j)T x j = ( x j ( 1 ) , x j ( 1 ) , . . . , x j ( n ) ) T ,则 xi,xj x i , x j 的距离为 Lp L p 距离定义为:
使用pycharm创建KNN.py文件,使用numpy库编写通用的函数来创建数据集和标签:
import numpy as np
'''
#创建数据集和标签
#param: none
#return:
array 数据集
list 标签
'''
def createDataSet():
group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0,0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
if __name__ == '__main__':
group, labels = createDataSet()
print(group)
print('--------')
print(labels)
'''输出结果:
[[1. 1.1]
[1. 1. ]
[0. 0. ]
[0. 0.1]]
--------
['A', 'A', 'B', 'B']
'''
上述矩阵group中有4组数据,每组数据有两个我们已知的属性或者特征值。向量标签labels包含了每个数据点的标签信息,labels包含的元素个数为group矩阵的行数。
import numpy as np
import operator
import time
'''
#创建数据集和标签
#param: none
#return:
array 数据集
list 标签
'''
def createDataSet():
group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0,0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
'''
#功能:分类器
#param:
inX 输入向量
dataSet 训练样本集
labels 标签向量
k 最近邻数目
#return: 分类标签
'''
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0] # 训练样本矩阵行数
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet # 将输入向量inX进行维度扩充(使维度与dataSet相同),并作矩阵之差
sqDiffMat = diffMat**2 # 将矩阵平方
sqDistances = sqDiffMat.sum(axis=1) # 将矩阵元素按行相加: sum(axis=1)按行相加,sum(axis=0)按列相加
distances = sqDistances**0.5 # 开根号,计算欧式距离
sortedDistIndices = distances.argsort() # 矩阵中的元素按从小到大进行排序后的索引值
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndices[i]] #距离最小前k个距离的标签
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 #从字典中统计标签的个数
# 将字典分解为元组列表[('A', 2), ('B', 1)...]
# 将元组列表标签个数按值value(用key=operator.itemgetter(1))进行从大到小排序
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
if __name__ == '__main__':
start = time.clock()
test = [1, 0.8]
group, labels = createDataSet()
test_class = classify0(test, group, labels, 3)
print(test_class)
end = time.clock()
print('Finished in', end - start)
'''输出结果:
A
Finished in 0.00019382568478496005
'''
上述classify0()函数有4个输入参数,并采用欧式距离来计算两个向量点的距离。计算完距离后,对数据进行从小到大次序排序。然后,确定前k个距离最小元素所在的主要分类。最后,将字典classCount分解为元组列表,再按第二个元素的次序对元组进行逆序排序。使用测试数据test = [1, 0.8]作为输入向量进行计算,计算结果为A。当然分类器是否会出错呢?答案是肯定的,分类器并不会得到百分百正确的结果,我们可以通过大量的测试数据,得到分类器的错误率—分类器给出错误结果的次数除以测试执行的总数。错误率是常用的评估方法,主要用于评估分类器在某个数据集上的执行效果。完美的分类器错误率为0,但通常都是出现过拟合的现象,最差分类器错误率是1.0,在这种情况下,分类器根本找不到一个正确答案。
上面的例子已经可以正常运转,但是并没有太大的实际用处,下面将会用两个具体的示例使用KNN算法改进约会网站的效果和使用KNN算法改进手写识别系统来展示现实世界中使用KNN算法。
背景:海伦一直使用在线约会网站寻找适合自己的约会对象。尽管约会网站会推荐不同的人选,但她并不喜欢每一个人。经过一番总结,她发现曾交往过这三种类型的人:
尽管发现了上述规律,但海伦依然无法将约会网站推荐的匹配对象归入恰当的类别。她觉得可以周一至周五约会那些魅力一般的人,而周末则更喜欢与那些极具魅力的人为伴。海伦希望我们的分类软件可以更好地帮助她将匹配对象划分到确切的分类中。此外海伦还收集了一些约会网站未曾记录的数据信息,她认为这些数据更有助于匹配对象的分类。
海伦进行了一段时间的约会数据收集,她把这些数据存放在datingTestSet.txt文本中,每个样本数据占一行,总共1000行。其样本数据主要包含以下3中特征:
40920 8.326976 0.953952 largeDoses
14488 7.153469 1.673904 smallDoses
26052 1.441871 0.805124 didntLike
将上述特征数据输入到分类器之前,必须将待处理数据的格式改变为分类器可以接受的格式。我们创建file2matrix()函数来处理输入格式问题:
'''
#功能:解析文本文件
#param:
fileName 文件名称txt
#return:
returnMat [matrix] 训练样本矩阵
classLabelVector [list] 类标签向量
'''
def file2Mmatrix(filename):
dict = {'didntLike':1, 'smallDoses':2, 'largeDoses':3}
fr = open(filename) #打开txt文本文件
arrayOLines = fr.readlines() #读取所有行,返回行列表
returnMat = np.zeros((len(arrayOLines), 3)) #准备返回样本矩阵
classLabelVector = [] #准备返回的类标签向量
index = 0
for line in arrayOLines:
line = line.strip() #截取头尾的所有回车字符
listFromLine = line.split('\t') #使用tab字符\t将行数据分割成一个元素列表
returnMat[index,:] = listFromLine[0:3]
if listFromLine[-1].isdigit():
classLabelVector.append(int(listFromLine[-1]))
else:
classLabelVector.append(dict.get(listFromLine[-1]))
index += 1
return returnMat, classLabelVector
if __name__ == '__main__':
start = time.clock()
fileName = 'datingTestSet.txt'
datingDataMat, datingLabels = file2Mmatrix(fileName)
print(datingDataMat)
print(datingLabels[0:20])
end = time.clock()
print('Finished in', end - start)
'''输出结果:
[[4.0920000e+04 8.3269760e+00 9.5395200e-01]
[1.4488000e+04 7.1534690e+00 1.6739040e+00]
[2.6052000e+04 1.4418710e+00 8.0512400e-01]
...
[2.6575000e+04 1.0650102e+01 8.6662700e-01]
[4.8111000e+04 9.1345280e+00 7.2804500e-01]
[4.3757000e+04 7.8826010e+00 1.3324460e+00]]
[3, 2, 1, 1, 1, 1, 3, 3, 1, 3, 1, 1, 2, 1, 1, 1, 1, 1, 2, 3]
Finished in 0.0046881340784240035
'''
现在我们已经从文本文件中导入了数据,并将其格式化为想要的格式,接着我们需要了解数据的真实含义,我们可以采用图形化的方式直观地展示数据。
我们可以在python环境中使用Matplotlib制作样本数据的散点图。
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号
'''
#功能:可视化训练样本
#param:
datingDataMat 训练样本矩阵
datingLabels 类标签向量
#return: plot图
'''
def viewDatas(datingDataMat, datingLabels):
LabelColors = []
for i in datingLabels:
if i == 1:
LabelColors.append('red')
elif i == 2:
LabelColors.append('green')
elif i == 3:
LabelColors.append('blue')
fig = plt.figure(figsize=(10, 8))
ax0 = fig.add_subplot(221)
ax1 = fig.add_subplot(222)
ax2 = fig.add_subplot(212)
#散点图1:以矩阵第2列,第3列绘制图,散点大小为15,透明度为0.5
ax0.scatter(datingDataMat[:,1], datingDataMat[:,2], s=15, color=LabelColors, alpha=0.5)
# 设置标题title, 标签label
axs0_title = ax0.set_title(u"玩视频游戏所消耗时间占比与每周消费的冰激淋公升数")
axs0_xlabel = ax0.set_xlabel(u"玩视频游戏所消耗时间占比")
axs0_ylabel = ax0.set_ylabel(u"每周消费的冰激淋公升数")
# 设置相应属性
plt.setp(axs0_title, size=9, weight='bold', color='red')
plt.setp(axs0_xlabel, size=8, weight='bold', color='black')
plt.setp(axs0_ylabel, size=8, weight='bold', color='black')
# 散点图2:以矩阵第1列,第2列绘制图,散点大小为15,透明度为0.5
ax1.scatter(datingDataMat[:, 0], datingDataMat[:, 1], s=15, c=LabelColors, alpha=0.5)
# 设置标题title, 标签label
axs1_title = ax1.set_title(u"每年获得的飞行常客里程数与玩视频游戏所消耗时间占比")
axs1_xlabel = ax1.set_xlabel(u"每年获得的飞行常客里程数")
axs1_ylabel = ax1.set_ylabel(u"玩视频游戏所消耗时间占比")
# 设置相应属性
plt.setp(axs1_title, size=9, weight='bold', color='red')
plt.setp(axs1_xlabel, size=8, weight='bold', color='black')
plt.setp(axs1_ylabel, size=8, weight='bold', color='black')
# 散点图3:以矩阵第1列,第3列绘制图,散点大小为15,透明度为0.5
ax2.scatter(datingDataMat[:, 0], datingDataMat[:, 2], s=15, c=LabelColors, alpha=0.5)
# 设置标题title, 标签label
axs2_title = ax2.set_title(u"每年获得的飞行常客里程数与每周消费的冰激淋公升数")
axs2_xlabel = ax2.set_xlabel(u"每年获得的飞行常客里程数")
axs2_ylabel = ax2.set_ylabel(u"每周消费的冰激淋公升数")
# 设置相应属性
plt.setp(axs2_title, size=9, weight='bold', color='red')
plt.setp(axs2_xlabel, size=8, weight='bold', color='black')
plt.setp(axs2_ylabel, size=8, weight='bold', color='black')
# 设置图例
didntLike = mlines.Line2D([], [], color='red', marker='.',
markersize=6, label=u'不喜欢')
smallDoses = mlines.Line2D([], [], color='green', marker='.',
markersize=6, label=u'魅力一般')
largeDoses = mlines.Line2D([], [], color='blue', marker='.',
markersize=6, label=u'极具魅力')
# 添加图例
ax0.legend(handles=[didntLike, smallDoses, largeDoses])
ax1.legend(handles=[didntLike, smallDoses, largeDoses])
ax2.legend(handles=[didntLike, smallDoses, largeDoses])
plt.show()
if __name__ == '__main__':
start = time.clock()
fileName = 'datingTestSet.txt'
datingDataMat, datingLabels = file2Mmatrix(fileName)
viewDatas(datingDataMat, datingLabels)
end = time.clock()
print('Finished in', end - start)
如下表2.1为训练样本部分数据:
序号 | 完视频游戏所消耗时间百分比 | 每年获得的飞行常客里程数 | 每周消费的冰淇淋公升数 | 样本分类 |
---|---|---|---|---|
1 | 0.8 | 400 | 0.5 | didntLike |
2 | 12 | 134000 | 0.9 | largeDoses |
3 | 0 | 2000 | 1.1 | smallDoses |
4 | 67 | 32000 | 0.1 | smallDoses |
如果我们用欧式距离计算样本3和样本4之间的距离,我们会通过如下公式:
'''
#功能:归一化数值
#param:
dataSet 训练样本矩阵
#return:
normDataSet 归一化矩阵
ranges 训练样本极值之差矩阵
minVals 训练样本组成最小值矩阵
'''
def autoNorm(dataSet):
minVals = dataSet.min(0) #选取列最小值
maxVals = dataSet.max(0) #选取列最大值
ranges = maxVals - minVals #最大值与最小值之差
normDataSet = np.zeros(np.shape(dataSet)) #定义归一化零矩阵
m = dataSet.shape[0] #行数
normDataSet = dataSet - np.tile(minVals, (m, 1)) #原样本矩阵减去矩阵列中的最小值
normDataSet = normDataSet/np.tile(ranges, (m, 1)) #矩阵归一化
return normDataSet, ranges, minVals
if __name__ == '__main__':
start = time.clock()
fileName = 'datingTestSet.txt'
datingDataMat, datingLabels = file2Mmatrix(fileName)
normMat, ranges, minVals = autoNorm(datingDataMat)
print(normMat)
print('--------')
print(ranges)
print('--------')
print(minVals)
end = time.clock()
print('Finished in', end - start)
'''输出结果:
[[0.44832535 0.39805139 0.56233353]
[0.15873259 0.34195467 0.98724416]
[0.28542943 0.06892523 0.47449629]
...
[0.29115949 0.50910294 0.51079493]
[0.52711097 0.43665451 0.4290048 ]
[0.47940793 0.3768091 0.78571804]]
--------
[9.1273000e+04 2.0919349e+01 1.6943610e+00]
--------
[0. 0. 0.001156]
Finished in 0.005247109984036027
'''
现在我们就已经通过autoNorm函数对原样本矩阵进行了归一化,下面我们将取值范围和最小值归一化测试数据。
上述我们已经将数据按照需求进行了处理,现在我们将测试分类器的效果,如果分类器的正确率满足要求,海伦就可以使用这个软件来处理约会网站提供的约会名单了。机器学习算法的一个很重要的工作就是评估算法的正确率,通常我们只提供已有数据的90%作为训练样本来训练分类器,而使用其余10%数据去测试分类器,检测分类器的正确率。接下来我们就使用这种原始方法来进行分类器的测试。我们通过构造datingClassTest()函数进行:
'''
#功能:分类器算法测试
#param: None
#return: 测试结果
'''
def datingClassTest():
hoRatio = 0.10 #测试数据为样本的10%
fileName = 'datingTestSet.txt'
datingDataMat, datingLabels = file2Mmatrix(fileName) #样本格式转换
normMat, ranges, minVals = autoNorm(datingDataMat) #样本数值归一化
m = normMat.shape[0] #行数
numTestVecs = int(m*hoRatio) #测试样本的个数
errorCount = 0.0 #统计错误结果个数
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m,:], datingLabels[numTestVecs:m], 3)
print('the classifier came back with: %d, the real answer is: %d' % (classifierResult, datingLabels[i]))
if classifierResult != datingLabels[i]: errorCount += 1
print('the total error rate is: %f' % (errorCount/float(numTestVecs)))
if __name__ == '__main__':
start = time.clock()
datingClassTest()
end = time.clock()
print('Finished in', end - start)
'''输出结果:
the classifier came back with: 3, the real answer is: 3
the classifier came back with: 1, the real answer is: 1
the classifier came back with: 2, the real answer is: 2
the classifier came back with: 1, the real answer is: 1
...
the classifier came back with: 3, the real answer is: 3
the classifier came back with: 3, the real answer is: 3
the classifier came back with: 2, the real answer is: 2
the classifier came back with: 1, the real answer is: 1
the classifier came back with: 3, the real answer is: 1
the total error rate is: 0.050000
Finished in 0.024232552925070386
'''
通过计算结果可以看出,分类器处理约会数据集的错误率为5%,当然我们可以通过改变函数内的hoRatio和变量k值来检测错误率的变化。此时5%的错误率,相对来说还算是不错的结果,海伦完全可以输入未知对象的属性信息,由分类软件来帮组她判定某一对象的可交往程度:讨厌、一般喜欢、非常喜欢。
现在我们终于可以使用这个分类器为海伦工作,我们通过编写一段小程序让海伦会在约会网站上找到某个人并输入他的信息。程序会给她对对方喜欢程序的预测值。
'''
#功能:约会网站预测函数
#param: None
#return: 预测结果
'''
def classifyPerson():
resultList = ['讨厌', '一般喜欢', '非常喜欢']
precentTats = float(input("玩视频游戏消耗时间百分比:"))
ffMiles = float(input("每年获得飞行的常客里程数:"))
iceCream = float(input("每周消耗的冰淇淋公升数:"))
datingDataMat, datingLabels = file2Mmatrix("datingTestSet.txt")
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = np.array([ffMiles, precentTats, iceCream]) #输入向量
normInArr = (inArr - minVals)/ranges #输入向量归一化数值
classifierResult = classify0(normInArr, normMat, datingLabels, 3)
print('u probaly like this person: ' + resultList[classifierResult - 1])
if __name__ == '__main__':
start = time.clock()
classifyPerson()
end = time.clock()
print('Finished in', end - start)
'''输出结果:
玩视频游戏消耗时间百分比:10
每年获得飞行的常客里程数:10000
每周消耗的冰淇淋公升数:0.5
u probaly like this person: 一般喜欢
Finished in 6.558792343609121
'''
到目前为止,我们已经完成了用KNN算法改进约会网站的配对效果。当然这里的所有数据让人会直观的看起来很容易,下一个示例我们会看到如何在二进制存储图像数据上使用KNN。
背景:本示例构造KNN分类器来识别数字0至9,需要识别的数字已经使用图形处理软件,处理成具有相同系统色彩和大小:宽高是32像素×32像素的黑白图像,如下图3.1为数字0的格式。
目前我们在trainingDigits文件中有如图3.1的数字集例子大约2000个,每个例子包括从0-9的数字大约有200个样本;在testDigits文件中包含大约900个测试数据。在这里我们将重用之前创建的分类器,首先需要将32×32的二进制矩阵转换为1×1024(32*32=1024)的向量。
我们首先创建img2Vecto函数,将图像转换为向量:将32*32矩阵中的每个特征值存储在1*1024的数组中作为训练样本的特征。
'''
#功能:将图片矩阵格式(32,32)转化为向量格式(1,1024)
#param:
filename 文件名
#return:
returnVect 图片格式数组
'''
def img2Vector(filename):
returnVect = np.zeros((1,1024))
fr = open(filename)
for i in range(32):
listStr = fr.readline()
for j in range(32):
returnVect[0, 32*i + j] = int(listStr[j])
return returnVect
if __name__ == '__main__':
start = time.clock()
testVect = img2Vector('0_0.txt')
print(testVect[0, 0:20])
end = time.clock()
print('Finished in', end - start)
'''输出结果:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0.]
Finished in 0.0012434845357894588
'''
上面我们已经将数据处理成分类器可以识别的格式,现在我们要将这些数据输入到分类器中,检测分类器的执行效果。函数handwritingClassTest()就是测试分类器的代码:
'''
#功能:手写数字识别系统的测试代码
#param: Nonee
#return: 测试结果
'''
def handwritingClassTest():
hwLabels = []
trainingFileList = os.listdir('./trainingDigits') #获取文件夹下的文件
m = len(trainingFileList) #文件的个数
trainingMat = np.zeros((m, 1024))
for i in range(m):
fileNameStr = trainingFileList[i] #第i个文件名
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])#将文件按'_'分割,获取矩阵实际标签
hwLabels.append(classNumStr) #添加标签
trainingMat[i,:] = img2Vector('./trainingDigits/' + fileNameStr) #格式转换
testFileList = os.listdir('./testDigits') #测试文件
errorCount = 0.0 #预测错误个数
mTest = len(testFileList) #测试数据数量
for i in range(mTest):
fileNameStr = testFileList[i] # 第i个文件名
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0]) # 将文件按'_'分割,获取矩阵实际标签
vectorUnderTest= img2Vector('./testDigits/' + fileNameStr) # 格式转换
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print('the classifier came back with: %d, the real answer is: %d' % (classifierResult, classNumStr))
if classifierResult != classNumStr: errorCount += 1.0
print('\nthe total number of errors is: %d' % errorCount)
print('\nthe rotal error rate is: %f' % (errorCount/float(mTest)))
if __name__ == '__main__':
start = time.clock()
handwritingClassTest()
end = time.clock()
print('Finished in', end - start)
'''输出结果:
...
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the total number of errors is: 10
the rotal error rate is: 0.010571
Finished in 34.70805667447233
'''
通过上述计算结果,我们知道KNN算法识别手写数字数据集的错误率为1.057%,当然可以通过改变k的值、修改随机选取训练样本、改变训练样本的数目,都会对KNN算法的错误率产生影响。实际使用这个算法的执行效率并不高。因为算法需要为每个测试向量做2000次距离计算,每个距离计算包括了1024个维度浮点运算,总计要执行900次,此外,我们要为测试向量准备2MB的存储空间。
通过对KNN算法的使用,我们知道采用KNN算法的优点:精度高、对异常值不敏感、无数据输入假定。缺点:计算复杂度高、空间复杂度高。试用数据范围:数值型和标称型。从计算的过程来看,实现KNN算法,主要是考虑的问题是如何对训练数据进行快速k近邻搜索,这点在特征空间的维数大及训练数据容量大时尤其重要。
KNN最简单的实现方法是线性扫描,这时要计算输入实例与每一个训练实例的距离。当训练集很大时,计算非常耗时,这种方法是不可行的。为了提高k近邻搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数。具体可以使用构造平衡kd树(kd tree)方法【具体参考相应的资料,本文不做介绍】。
【参考】:
1. 《机器学习实战》作者:Peter Harrington 第2章 K-近邻算法
2. 《统计学习方法》作者:李航 第3章 K-近邻法
转载声明:
版权声明:非商用自由转载-保持署名-注明出处
署名 :mcyJacky
文章出处:https://blog.csdn.net/mcyJacky