机器学习系列--KNN分类算法例子

url:机器学习系列--KNN分类算法

用的是spark2.0.2,scala2.11

import org.apache.spark.{SparkConf, SparkContext}

object knntest {

  /**
    * 欧式距离
    * 计算两点间的距离
    * @param rs as r1,r2, ..., rd
    * @param ss as s1,s2, ..., sd
    * @param d 维数
    */
  def euclideanDistance(rs: String, ss: String, d: Int): Double = {
    val r = rs.split(",").map(_.toDouble)
    val s = ss.split(",").map(_.toDouble)

    if (r.length != d || s.length != d) Double.NaN else {
      //zip匹配key/value 分区数一样,ri-si的平方的求和再开方,欧式距离
      math.sqrt((r, s).zipped.take(d).map {
        case (ri, si) => math.pow(ri - si, 2)
      }.sum)
    }
  }

  def main(args: Array[String]): Unit = {
    val sparkConf=new SparkConf().setAppName("knntest").setMaster("local[4]")
    val sc=new SparkContext(sparkConf)

    //生成矩阵,每行代表一个样本 10为索引,A,B为类别,其它为属性1,2..
    val groupes=sc.parallelize(List("10;A;1.0,0.9", "11;A;1.0,1.0", "12;B;0.1,0.2", "13;B;0.0,0.1"))
    //100为索引,其它为属性1,2..
    val testxs = sc.parallelize(List("100;1.2,1.0","101;0.1,0.3"))
    //近邻数
    val k = sc.broadcast(3)
    //向量维度
    val d = sc.broadcast(2)
    //笛卡尔
    //ArrayBuffer((100;1.2,1.0,10;A;1.0,0.9),
    // (100;1.2,1.0,11;A;1.0,1.0),
    // (100;1.2,1.0,12;B;0.1,0.2),
    // (100;1.2,1.0,13;B;0.0,0.1),
    // (101;0.1,0.3,10;A;1.0,0.9),
    // (101;0.1,0.3,11;A;1.0,1.0),
    // (101;0.1,0.3,12;B;0.1,0.2),
    // (101;0.1,0.3,13;B;0.0,0.1))
    val cart=testxs.cartesian(groupes)

    val knns=cart.map(p=>{
      val testx=p._1//例 100;1.2,1.0
      val group2=p._2//例 10;A;1.0,0.9
      val testx_index=testx.split(";")(0)
      val testx_rs=testx.split(";")(1)

      //类型
      val group2_type=group2.split(";")(1)
      val group2_ss=group2.split(";")(2)
      //欧式距离
      val distance =euclideanDistance(testx_rs, group2_ss, d.value)
      //ArrayBuffer((100,(0.2236067977499789,A)), (100,(0.19999999999999996,A)),
      // (100,(1.3601470508735443,B)), (100,(1.5,B)),
      // (101,(1.0816653826391969,A)), (101,(1.140175425099138,A)),
      // (101,(0.09999999999999998,B)), (101,(0.22360679774997896,B)))
      (testx_index,(distance,group2_type))
    })

    val knnGrouped = knns.groupByKey()

    val knnOutput = knnGrouped.mapValues(itr => {
      //(100,List((0.19999999999999996,A), (0.2236067977499789,A), (1.3601470508735443,B)))
      //(101,List((0.09999999999999998,B), (0.22360679774997896,B), (1.0816653826391969,A)))
      val nearestK = itr.toList.sortBy(_._1).take(k.value)
      //(101,List((B,1), (B,1), (A,1)))
      //(100,List((A,1), (A,1), (B,1)))
      //(100,Map(A -> 2, B -> 1))
      //(101,Map(A -> 1, B -> 2))
      val majority = nearestK.map(f => (f._2, 1)).groupBy(_._1).mapValues(list => {
        val (stringList, intlist) = list.unzip
        intlist.sum
      })
      //(100,A)
      //(101,B)
      majority.maxBy(_._2)._1
    })

    knnOutput.foreach(println)
    sc.stop()
  }
}

你可能感兴趣的:(数据挖掘,机器学习)