K-Nearest Neighbors on the Iris Dataset: A KNN Demo Implemented with Spark

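1.1 The K-Nearest Neighbors (KNN) Concept

The K Nearest Neighbor algorithm, also called KNN, is one of the classic algorithms in machine learning, and overall it is relatively easy to understand.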

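Definition

If the majority of the k most similar samples in feature space (that is, the k nearest neighbors in feature space) belong to a certain class, then the sample also belongs to that class.

Origin: KNN was first proposed by Cover and Hart as a classification algorithm.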

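Distance formula
The distance between two samples can be computed with the following formula, also known as the Euclidean distance:
[Image: Euclidean distance formula]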
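The formula was an image in the original post. For reference, this is the standard Euclidean distance between two samples a and b with n features (the symbol names are my own choice; for Iris, n = 4):

```latex
d(a, b) = \sqrt{\sum_{i=1}^{n} (a_i - b_i)^2}
```

For example, for a = (5.3, 3.7, 1.5, 0.2) and b = (5.1, 3.5, 1.4, 0.2), both taken from the data file in this post, the squared differences are 0.04, 0.04, 0.01 and 0, so d = sqrt(0.09) = 0.3.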

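1.3 Summary of the KNN Algorithm

1) Compute the distance between the current point and every point in the dataset with known classes

2) Sort by distance in ascending order

3) Select the k points closest to the current point

4) Count how often each class occurs among those k points

5) Return the most frequent class among those k points as the predicted class of the current point

My own steps:

1. Compute the distance from the unknown point to every sample point (points with a known class)
2. Sort by distance
3. Take the first K points
4. Do a wordcount on the label (the class information)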
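The steps above can be sketched on plain Scala collections, without Spark. This is only a minimal illustration under my own assumptions: the names `KnnSketch`, `Sample`, `distance` and `classify` are hypothetical and are not the ones used in the Spark demo in this post, and ties between labels are resolved arbitrarily.

```scala
object KnnSketch {
  // Hypothetical container for a labeled sample (not the case class of the Spark demo).
  final case class Sample(label: String, features: Vector[Double])

  // Euclidean distance between two feature vectors of equal length.
  def distance(a: Vector[Double], b: Vector[Double]): Double =
    math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)

  // Steps 1-4, followed by picking the most frequent label.
  def classify(samples: Seq[Sample], point: Vector[Double], k: Int): String = {
    val nearestLabels = samples
      .map(s => (s.label, distance(s.features, point))) // 1. distance to every known point
      .sortBy(_._2)                                     // 2. sort by distance, ascending
      .take(k)                                          // 3. keep the k nearest
      .map(_._1)
    val counts = nearestLabels
      .groupBy(label => label)                          // 4. "wordcount" on the labels
      .map { case (label, hits) => (label, hits.size) }
    counts.maxBy(_._2)._1                               // most frequent label wins
  }

  def main(args: Array[String]): Unit = {
    val samples = Seq(
      Sample("setosa", Vector(5.1, 3.5, 1.4, 0.2)),
      Sample("setosa", Vector(4.9, 3.0, 1.4, 0.2)),
      Sample("versicolor", Vector(7.0, 3.2, 4.7, 1.4)),
      Sample("versicolor", Vector(6.4, 3.2, 4.5, 1.5)),
      Sample("virginica", Vector(6.3, 3.3, 6.0, 2.5))
    )
    println(classify(samples, Vector(5.3, 3.7, 1.5, 0.2), k = 3)) // prints "setosa"
  }
}
```

With k = 3 the three nearest samples are the two setosa rows and one versicolor row, so the majority label is setosa.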

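2 Summary

Introduction to K-nearest neighbors [for awareness]
Definition: decide which class you belong to by looking at your "neighbors"
How to compute the distance to your "neighbors": in most cases, Euclidean distance is used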

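Case study: predicting Iris species

The Iris dataset is a commonly used dataset for classification experiments, compiled by Fisher in 1936. Also known as the Iris flower dataset, it is a dataset for multivariate analysis. A detailed description of the dataset:
[Image: description of the Iris dataset]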
TODO optimization: arr.last and arr.init
Make flexible use of an array's .last (takes the last element) and .init (takes every element except the last)
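As a small illustration of `.last` and `.init`, here is a sketch that splits one CSV line into a label and its numeric features. The object and method names (`ParseLineSketch`, `parse`) are hypothetical, and the sketch assumes that a labeled line has exactly 5 fields, as in the data file shown in this post.

```scala
object ParseLineSketch {
  // Splits one CSV line into (label, features); an unlabeled line gets an empty label.
  // Assumption: a labeled line has exactly 5 fields (4 numbers followed by the label).
  def parse(line: String): (String, Array[Double]) = {
    val fields = line.split(",")
    if (fields.length == 5) (fields.last, fields.init.map(_.toDouble)) // last = label, init = the 4 numbers
    else ("", fields.map(_.toDouble))                                  // no label present
  }

  def main(args: Array[String]): Unit = {
    val (label, features) = parse("5.1,3.5,1.4,0.2,setosa")
    println(s"$label -> ${features.mkString(", ")}") // prints "setosa -> 5.1, 3.5, 1.4, 0.2"
  }
}
```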

package IrisKNN.teacher

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Created by Shi shuai RollerQing on 2019/12/30 16:23
 * Iris example
 * KNN algorithm
 */
object KNNDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("KNNDemo").setMaster("local[*]")
    val sc = new SparkContext(conf)
    val K = 9
    // 1. Parse each line into the LabeledPoint case class
    val lines = sc.textFile("C:\\Users\\HP\\IdeaProjects\\sparkCore\\data\\iris.dat")
      .map{line=>
        val fields = line.split(",")
        if (fields.length == 5) LabeledPoint(fields.last, fields.init.map(_.toDouble))
        else LabeledPoint("", fields.map(_.toDouble)) // no label: use an empty label, still wrapped as a LabeledPoint
      }
    // 2. Split into the sample set (label known) and the test set (label unknown);
    //    the unlabeled points are collected into a local array so they can drive the outer loop
    val sampleRDD: RDD[LabeledPoint] = lines.filter(_.label != "")
    val testData: Array[Array[Double]] = lines.filter(_.label == "").map(_.point).collect()

    // The small table goes on the outside, like putting the small table on the right side of a join
    testData.foreach(point => {
      sampleRDD.map(labeledpoint => {
        (labeledpoint.label, getDistance(point, labeledpoint.point)) // (label of the known point, distance between the known and unknown points)
      }).sortBy(_._2) // sort by distance, ascending
        .take(K) // keep the K (= 9) nearest
        .map{case (label, _) => label} // the distance is no longer needed; keep only the label for the WordCount
        .groupBy(x=>x)
        .mapValues(_.length)
        .foreach(print)
    })

    sc.stop()
  }
  import scala.math._
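  def getDistance(x: Array[Double], y: Array[Double]): Double = // x and y are both arrays of 4 values
    sqrt(x.zip(y).map(elem => pow(elem._1 - elem._2, 2)).sum)
  // zip pairs the two arrays up, giving 4 tuples like ((x1,y1), (x2,y2), (x3,y3), (x4,y4));
  // square each tuple's difference, sum the squares, then take sqrt to get the distance

}
// Case class: the label first, then the four numbers stored in one array, because together they represent a single point (one coordinate)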
case class LabeledPoint(label: String, point:Array[Double])

The idea is very simple: compute the distances, sort, take the first K, then do a WordCount on the labels.


Expected results
5.3,3.7,1.5,0.2,setosa
5,3.3,1.4,0.2,setosa
5.1,2.5,3,1.1,versicolor
5.7,2.8,4.1,1.3,versicolor
6.2,3.4,5.4,2.3,virginica
5.9,3,5.1,1.8,virginica
Actual results:
ArrayBuffer(5.3, 3.7, 1.5, 0.2) List((setosa,9))
ArrayBuffer(5.0, 3.3, 1.4, 0.2) List((setosa,9))
ArrayBuffer(5.1, 2.5, 3.0, 1.1) List((versicolor,9))
ArrayBuffer(5.7, 2.8, 4.1, 1.3) List((versicolor,9))
ArrayBuffer(6.2, 3.4, 5.4, 2.3) List((virginica,9))
ArrayBuffer(5.9, 3.0, 5.1, 1.8) List((virginica,7))
The results are correct.

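Problem:

While retyping the code I ran into the problem below.
It looks like a matter of Scala syntax; I haven't studied it well enough to be sure.

TODO: note that groupBy(_) does not return what you would expect here.
It is better to write groupBy(x => x).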

 val functionToMap:   (String => Nothing) => Map[Nothing, Array[String]] = strings.groupBy(_)
 val stringToStrings: Map[String, Array[String]]                         = strings.groupBy(x =>x)
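A likely explanation, based on how Scala treats the placeholder: when `_` stands for the entire argument, `strings.groupBy(_)` is expanded to `f => strings.groupBy(f)`, that is, a function still waiting for the grouping function, rather than a call with the identity function. That matches the function type shown in the first line above. The sketch below uses hypothetical names (`GroupByIdentitySketch`, `countLabels`) and shows the form that does count the labels.

```scala
object GroupByIdentitySketch {
  // Counts how often each label occurs, using an explicit identity function as the key.
  def countLabels(labels: Array[String]): Map[String, Int] =
    labels
      .groupBy(x => x) // same effect as groupBy(identity); NOT the same as groupBy(_)
      .map { case (label, group) => (label, group.length) }

  def main(args: Array[String]): Unit = {
    val labels = Array("setosa", "setosa", "virginica")
    println(countLabels(labels)) // a Map with setosa -> 2 and virginica -> 1
  }
}
```

Writing `groupBy(identity)` is an equivalent alternative that avoids the ambiguity altogether.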


The code I originally wrote myself

package IrisKNN

import java.util

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Created by Shi shuai RollerQing on 2019/12/27 14:12
 *
 * 6.5,3,5.2,2,virginica
 * 6.2,3.4,5.4,2.3
 *
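 * An Array cannot be used as a key; a List apparently can.
 *
 * 1. Arrays cannot be keys: HashPartitioner cannot partition array keys
 * 2. Task not serializable
 * 3. RDDs cannot be nested
 */
// 1. Compute the distance from the unknown point to every sample point (points with a known class)
// 2. Sort by distance
// 3. Take the first K points
// 4. WordCount on the label (the class information)
// k = 9
// Distance (squared, shown for two dimensions): (x1 - x2)^2 + (y1 - y2)^2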
case class known(x: Double, y: Double, z: Double, w: Double, label: String)
case class unknown(x: Double, y: Double, z: Double, w: Double)
case class disLab(distance: Double, label: String)
object IrisTest {
  def main(args: Array[String]): Unit = {
    val path = "C:\\Users\\HP\\IdeaProjects\\sparkCore\\data\\iris.dat"
    val k: Int = 9
    //SparkSession
    val spark = SparkSession
      .builder
      .appName(IrisTest.getClass.getSimpleName)
      .master("local[*]")
      .getOrCreate()

    val sc: SparkContext = spark.sparkContext
    val data: RDD[String] = sc.textFile(path)

    val arr: RDD[Array[String]] = data.map(line => line.split(","))

    import spark.implicits._

    val knownDataSet: Dataset[known] = arr.filter(_.length > 4).map(t => known(t(0).toDouble, t(1).toDouble,t(2).toDouble,t(3).toDouble,t(4))).toDS
    val unknownDataSet: Dataset[unknown] = arr.filter(_.length <= 4).map(t => unknown(t(0).toDouble, t(1).toDouble, t(2).toDouble, t(3).toDouble)).toDS
    val unknowns: util.List[unknown] = unknownDataSet.collectAsList()


    import scala.collection.JavaConversions._
    for (unknownItem <- unknowns) {
      // 1. Distance from the unknown point to every known sample  (((1, x) (2, y) (3, z) (4, w)), label)
      val distanceAndLabel: Dataset[disLab] = knownDataSet.map(known => {
        // Keep each '+' at the end of the line: in Scala 2 a line that starts with '+'
        // is parsed as a separate statement, which would silently drop the z and w terms.
        val ping = math.pow((known.x - unknownItem.x), 2.0) + math.pow((known.y - unknownItem.y), 2.0) +
          math.pow((known.z - unknownItem.z), 2.0) + math.pow((known.w - unknownItem.w), 2.0)
        val label = known.label

        val distance: Double = scala.math.sqrt(ping)
        disLab(distance, label)
      })
      //distanceAndLabel.foreach(rdd => println(rdd))
      // 2. Sort by distance
      val sorted: Dataset[disLab] = distanceAndLabel.sort($"distance")
      // 3. Take the first K points
      val tuples: Array[disLab] = sorted.take(k)
      // 4. WordCount on the label (the class information)
      val grouped: Map[String, Array[disLab]] = tuples.groupBy(_.label)
      //grouped.map(x => (x._1, x._2.size))
      val sumed: Map[String, Int] = grouped.mapValues(x => x.size)
      // Pick the most frequent label
      val res = sumed.toList.sortBy(_._2).reverse.take(1).map(_._1).toArray
      // The result
      println(unknownItem.toString + "\t" + res(0))
    }


  }
}
// Expected results
//5.3,3.7,1.5,0.2,setosa
//5,3.3,1.4,0.2,setosa
//5.1,2.5,3,1.1,versicolor
//5.7,2.8,4.1,1.3,versicolor
//6.2,3.4,5.4,2.3,virginica
//5.9,3,5.1,1.8,virginica
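// Actual results as recorded by the author. The last point does not match the expected label,
// which would be consistent with a distance that omits the z and w terms.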
//unknown(5.3,3.7,1.5,0.2)	setosa
//unknown(5.0,3.3,1.4,0.2)	setosa
//unknown(5.1,2.5,3.0,1.1)	versicolor
//unknown(5.7,2.8,4.1,1.3)	versicolor
//unknown(6.2,3.4,5.4,2.3)	virginica
//unknown(5.9,3.0,5.1,1.8)	versicolor
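One thing to watch for when a long sum such as the squared distance is split across lines: in Scala 2, a line that begins with `+` is treated as a new statement, so the terms on that line are not added to the value being assigned. Scala 3 handles leading infix operators differently, so the behavior depends on the compiler version. The layout below, with each `+` at the end of the line, is unambiguous in both. The names `LineBreakSketch` and `squaredDistance` are hypothetical and only serve this sketch.

```scala
object LineBreakSketch {
  // Operators at the END of each line: the expression clearly continues on the next line.
  // Assumes both arrays hold exactly 4 values, like the Iris feature vectors.
  def squaredDistance(a: Array[Double], b: Array[Double]): Double =
    math.pow(a(0) - b(0), 2.0) + math.pow(a(1) - b(1), 2.0) +
      math.pow(a(2) - b(2), 2.0) + math.pow(a(3) - b(3), 2.0)

  def main(args: Array[String]): Unit = {
    val known = Array(5.1, 3.5, 1.4, 0.2)
    val unknown = Array(5.3, 3.7, 1.5, 0.2)
    // Euclidean distance: square root of the sum of all four squared differences.
    println(math.sqrt(squaredDistance(known, unknown)))
  }
}
```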

Test data (rows without a label are the points to classify)

5.1,3.5,1.4,0.2,setosa
4.9,3,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
5,3.6,1.4,0.2,setosa
5.4,3.9,1.7,0.4,setosa
4.6,3.4,1.4,0.3,setosa
5,3.4,1.5,0.2,setosa
4.4,2.9,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5.4,3.7,1.5,0.2,setosa
4.8,3.4,1.6,0.2,setosa
4.8,3,1.4,0.1,setosa
4.3,3,1.1,0.1,setosa
5.8,4,1.2,0.2,setosa
5.7,4.4,1.5,0.4,setosa
5.4,3.9,1.3,0.4,setosa
5.1,3.5,1.4,0.3,setosa
5.7,3.8,1.7,0.3,setosa
5.1,3.8,1.5,0.3,setosa
5.4,3.4,1.7,0.2,setosa
5.1,3.7,1.5,0.4,setosa
4.6,3.6,1,0.2,setosa
5.1,3.3,1.7,0.5,setosa
4.8,3.4,1.9,0.2,setosa
5,3,1.6,0.2,setosa
5,3.4,1.6,0.4,setosa
5.2,3.5,1.5,0.2,setosa
5.2,3.4,1.4,0.2,setosa
4.7,3.2,1.6,0.2,setosa
4.8,3.1,1.6,0.2,setosa
5.4,3.4,1.5,0.4,setosa
5.2,4.1,1.5,0.1,setosa
5.5,4.2,1.4,0.2,setosa
4.9,3.1,1.5,0.2,setosa
5,3.2,1.2,0.2,setosa
5.5,3.5,1.3,0.2,setosa
4.9,3.6,1.4,0.1,setosa
4.4,3,1.3,0.2,setosa
5.1,3.4,1.5,0.2,setosa
5,3.5,1.3,0.3,setosa
4.5,2.3,1.3,0.3,setosa
4.4,3.2,1.3,0.2,setosa
5,3.5,1.6,0.6,setosa
5.1,3.8,1.9,0.4,setosa
4.8,3,1.4,0.3,setosa
5.1,3.8,1.6,0.2,setosa
4.6,3.2,1.4,0.2,setosa
5.3,3.7,1.5,0.2
5,3.3,1.4,0.2
7,3.2,4.7,1.4,versicolor
6.4,3.2,4.5,1.5,versicolor
6.9,3.1,4.9,1.5,versicolor
5.5,2.3,4,1.3,versicolor
6.5,2.8,4.6,1.5,versicolor
5.7,2.8,4.5,1.3,versicolor
6.3,3.3,4.7,1.6,versicolor
4.9,2.4,3.3,1,versicolor
6.6,2.9,4.6,1.3,versicolor
5.2,2.7,3.9,1.4,versicolor
5,2,3.5,1,versicolor
5.9,3,4.2,1.5,versicolor
6,2.2,4,1,versicolor
6.1,2.9,4.7,1.4,versicolor
5.6,2.9,3.6,1.3,versicolor
6.7,3.1,4.4,1.4,versicolor
5.6,3,4.5,1.5,versicolor
5.8,2.7,4.1,1,versicolor
6.2,2.2,4.5,1.5,versicolor
5.6,2.5,3.9,1.1,versicolor
5.9,3.2,4.8,1.8,versicolor
6.1,2.8,4,1.3,versicolor
6.3,2.5,4.9,1.5,versicolor
6.1,2.8,4.7,1.2,versicolor
6.4,2.9,4.3,1.3,versicolor
6.6,3,4.4,1.4,versicolor
6.8,2.8,4.8,1.4,versicolor
6.7,3,5,1.7,versicolor
6,2.9,4.5,1.5,versicolor
5.7,2.6,3.5,1,versicolor
5.5,2.4,3.8,1.1,versicolor
5.5,2.4,3.7,1,versicolor
5.8,2.7,3.9,1.2,versicolor
6,2.7,5.1,1.6,versicolor
5.4,3,4.5,1.5,versicolor
6,3.4,4.5,1.6,versicolor
6.7,3.1,4.7,1.5,versicolor
6.3,2.3,4.4,1.3,versicolor
5.6,3,4.1,1.3,versicolor
5.5,2.5,4,1.3,versicolor
5.5,2.6,4.4,1.2,versicolor
6.1,3,4.6,1.4,versicolor
5.8,2.6,4,1.2,versicolor
5,2.3,3.3,1,versicolor
5.6,2.7,4.2,1.3,versicolor
5.7,3,4.2,1.2,versicolor
5.7,2.9,4.2,1.3,versicolor
6.2,2.9,4.3,1.3,versicolor
5.1,2.5,3,1.1
5.7,2.8,4.1,1.3
6.3,3.3,6,2.5,virginica
5.8,2.7,5.1,1.9,virginica
7.1,3,5.9,2.1,virginica
6.3,2.9,5.6,1.8,virginica
6.5,3,5.8,2.2,virginica
7.6,3,6.6,2.1,virginica
4.9,2.5,4.5,1.7,virginica
7.3,2.9,6.3,1.8,virginica
6.7,2.5,5.8,1.8,virginica
7.2,3.6,6.1,2.5,virginica
6.5,3.2,5.1,2,virginica
6.4,2.7,5.3,1.9,virginica
6.8,3,5.5,2.1,virginica
5.7,2.5,5,2,virginica
5.8,2.8,5.1,2.4,virginica
6.4,3.2,5.3,2.3,virginica
6.5,3,5.5,1.8,virginica
7.7,3.8,6.7,2.2,virginica
7.7,2.6,6.9,2.3,virginica
6,2.2,5,1.5,virginica
6.9,3.2,5.7,2.3,virginica
5.6,2.8,4.9,2,virginica
7.7,2.8,6.7,2,virginica
6.3,2.7,4.9,1.8,virginica
6.7,3.3,5.7,2.1,virginica
7.2,3.2,6,1.8,virginica
6.2,2.8,4.8,1.8,virginica
6.1,3,4.9,1.8,virginica
6.4,2.8,5.6,2.1,virginica
7.2,3,5.8,1.6,virginica
7.4,2.8,6.1,1.9,virginica
7.9,3.8,6.4,2,virginica
6.4,2.8,5.6,2.2,virginica
6.3,2.8,5.1,1.5,virginica
6.1,2.6,5.6,1.4,virginica
7.7,3,6.1,2.3,virginica
6.3,3.4,5.6,2.4,virginica
6.4,3.1,5.5,1.8,virginica
6,3,4.8,1.8,virginica
6.9,3.1,5.4,2.1,virginica
6.7,3.1,5.6,2.4,virginica
6.9,3.1,5.1,2.3,virginica
5.8,2.7,5.1,1.9,virginica
6.8,3.2,5.9,2.3,virginica
6.7,3.3,5.7,2.5,virginica
6.7,3,5.2,2.3,virginica
6.3,2.5,5,1.9,virginica
6.5,3,5.2,2,virginica
6.2,3.4,5.4,2.3
5.9,3,5.1,1.8
