KNN(k-NearestNeighbor)又被称为最近邻算法,它的核心思想是:物以类聚,人以群分。KNN算法是机器学习 中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来 代表。KNN是一种分类算法,KNN没有显式的学习过程,也就是说没有训练阶段,待收到新样本后直接进行处理。
KNN的思路是:如果一个样本在特征空间中的k个最邻近的样本中的大多数属于某一个类别,则该样本也划分为这 个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个 样本的类别来决定待分样本所属的类别。
提到KNN,网上最常见的就是下面这个图,可以帮助大家理解。我们要确定绿点属于哪个颜色(红色或者蓝色), 要做的就是选出距离目标点距离最近的k个点,看这k个点的大多数颜色是什么颜色。当k取3的时候,我们可以看出 距离最近的三个,分别是红色、红色、蓝色,因此判定目标点为红色。
1)分别读取测试数据、训练数据集;
2)计算测试数据与训练数据之间的距离;
3)选取距离最小的K个点;
4) 确定前K个点所在类别的出现频率;
5)返回前K个点中出现频率最高的类别作为测试数据的预测分类
/**
* 数据集:鸢尾花数据集
*
* Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。
* 数据集内包含 3 类共 150 条 记录,每类各 50 个数据,
* 每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,
* 可以通过这4个 特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。
*
* 数据:
* 6.2,3.4,5.4,2.3,virginica
* 5.9,3,5.1,1.8,virginica
* 5.8,3.4,2.6,2.2
* 5.5,2.3,3.3,1.9
* 5.3,3.7,1.5,0.2,setosa
* 5,3.3,1.4,0.2,setosa
* 7,3.2,4.7,1.4,versicolor
* .........等
*/
package algorithm.MachineLearning
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
object SimpleKNN {
def main(args: Array[String]): Unit = {
//1.初始化
val conf=new SparkConf().setAppName("SimpleKnn").setMaster("local[*]")
val sc=new SparkContext(conf)
val K=15
//2.读取数据,封装数据
val data: RDD[LabelPoint] = sc.textFile("file:///H:\\IDEA2019_WorkSpace\\SparkLearning\\src\\main\\data\\iris.csv")
.map(line => {
val arr = line.split(",")
if (arr.length == 5) {
LabelPoint(arr.last, arr.init.map(_.toDouble))
} else {
LabelPoint(" ", arr.map(_.toDouble))
}
})
//3.过滤出样本数据和测试数据
val sampleData=data.filter(_.label!=" ")
val testData=data.filter(_.label==" ").map(_.point).collect()
//4.求每一条测试数据与样本数据的距离
testData.foreach(elem=>{
val distance=sampleData.map(x=>(getDistance(elem,x.point),x.label))
//获取距离最近的k个样本
val minDistance=distance.sortBy(_._1).take(K)
//取出这k个样本的label并且获取出现最多的label即为测试数据的label
val labels=minDistance.map(_._2)
.groupBy(x=>x)
.mapValues(_.length)
.toList
.sortBy(_._2).reverse
.take(1)
.map(_._1)
printf(s"${elem.toBuffer.mkString(",")},${labels.toBuffer.mkString(",")}")
println()
})
sc.stop()
}
case class LabelPoint(label:String,point:Array[Double])
import scala.math._
def getDistance(x:Array[Double],y:Array[Double]):Double={
sqrt(x.zip(y).map(z=>pow(z._1-z._2,2)).sum)
}
}
package algorithm.MachineLearning
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import scala.collection.immutable.TreeSet
object SuperKNN {
def main(args: Array[String]): Unit = {
//1.初始化
val conf=new SparkConf().setAppName("SimpleKnn").setMaster("local[*]")
val sc=new SparkContext(conf)
val K=15
//2.读取数据,封装数据
val data: RDD[LabelPoint] = sc.textFile("file:///H:\\IDEA2019_WorkSpace\\SparkLearning\\src\\main\\data\\iris.csv")
.map(line => {
val arr = line.split(",")
if (arr.length == 5) {
LabelPoint(arr.last, arr.init.map(_.toDouble))
} else {
LabelPoint(" ", arr.map(_.toDouble))
}
})
//3.过滤出样本数据和测试数据
val sampleData=data.filter(_.label!=" ")
val testData=data.filter(_.label==" ").map(_.point).collect()
//4.将testData封装到广播变量做一个优化
val bc_testData=sc.broadcast(testData)
//5.求每一条测试数据与样本数据的距离----使用mapPartitions应对大量数据集进行优化
val distance: RDD[(String,(Double,String))] = sampleData.mapPartitions(iter => {
val bc_points = bc_testData.value
iter.flatMap(x => bc_points.map(point2 => (point2.mkString(","), (getDistance(point2, x.point),x.label))))
})
//6.求距离最小的k个点,使用aggregateByKey---先分局内聚合,再全局聚合
distance.aggregateByKey(TreeSet[(Double,String)]())(
(splitSet:TreeSet[(Double,String)],elem:(Double,String))=>{
val newSet=splitSet+elem //TreeSet默认是有序的(升序)
newSet.take(K)
},
(splitSet1:TreeSet[(Double,String)],splitSet2:TreeSet[(Double,String)])=>{
(splitSet1 ++ splitSet2).take(K)
}
)
//7.取出距离最小的k个点中出现次数最多的label---即为样本数据的label
.map(x=>{
(
x._1,
x._2.toArray.map(_._2).groupBy(y=>y).map(z=>(z._1,z._2.length)).toList.sortBy(_._2).map(_._1).take(1).mkString(",")
)
}).foreach(x=> println(x))
sc.stop()
}
case class LabelPoint(label:String,point:Array[Double])
import scala.math._
def getDistance(x:Array[Double],y:Array[Double]):Double={
sqrt(x.zip(y).map(z=>pow(z._1-z._2,2)).sum)
}
}
(1)理论成熟,思想简单,既可以用来做分类也可以用来做回归;
(2)可用于非线性分类;
(3)训练时间复杂 度低,为O(n);
(4) 和朴素贝叶斯之类的算法比,对数据没有假设,准确度高,对异常点不敏感;
(1)计算量大,尤其是特征数非常多的时候;
(2)样本不平衡的时候,对稀有类别的预测准 确率低;
(3)kd树,球树之类的模型建立需要大量的内存;
(4)使用懒散学习方法,基本上不学习,导致预测 时速度比起逻辑回归之类的算法慢;
(5)KNN模型可解释性不强。