Spark:Scala实现KMeans算法

1 什么是KMeans算法
K-Means算法是一种cluster analysis的算法,其主要是来计算数据聚集的算法,主要通过不断地取离种子点最近均值的算法。
具体来说,通过输入聚类个数k,以及包含 n个数据对象的数据库,输出满足方差最小标准的k个聚类。

2 k-means 算法基本步骤
(1) 从 n个数据对象任意选择 k 个对象作为初始聚类中心; 
(2) 根据每个聚类对象的均值(中心对象),计算每个对象与这些中心对象的距离;并根据最小距离重新对相应对象进行划分;
(3) 重新计算每个(有变化)聚类的均值(中心对象);
(4) 计算标准测度函数,当满足一定条件,如函数收敛时,则算法终止;如果条件不满足则回到步骤(2)。
算法的时间复杂度上界为O(n*k*t), 其中t是迭代次数,n个数据对象划分为 k个聚类。


3 MLlib实现KMeans
以MLlib实现KMeans算法,利用MLlib KMeans训练出来的模型,可以对新的数据作出分类预测,具体见代码和输出结果。
Scala代码:

  1. 1 package com.hq

  2. 3 import org.apache.spark.mllib.clustering.KMeans
  3. 4 import org.apache.spark.mllib.linalg.Vectors
  4. 5 import org.apache.spark.{SparkContext, SparkConf}

  5. 7 object KMeansTest {
  6. 8   def main(args: Array[String]) {
  7. 9     if (args.length < 1) {
  8. 10       System.err.println("Usage: ")
  9. 11       System.exit(1)
  10. 12     }
  11. 13 
  12. 14     val conf = new SparkConf()
  13. 15     val sc = new SparkContext(conf)
  14. 16     val data = sc.textFile(args(0))
  15. 17     val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
  16. 18     val numClusters = 2
  17. 19     val numIterations = 20
  18. 20     val clusters = KMeans.train(parsedData,numClusters,numIterations)
  19. 21 
  20. 22     println("------Predict the existing line in the analyzed data file: "+args(0))
  21. 23     println("Vector 1.0 2.1 3.8 belongs to clustering "+ clusters.predict(Vectors.dense("1.0 2.1 3.8".split(' ').map(_.toDouble))))
  22. 24     println("Vector 5.6 7.6 8.9 belongs to clustering "+ clusters.predict(Vectors.dense("5.6 7.6 8.9".split(' ').map(_.toDouble))))
  23. 25     println("Vector 3.2 3.3 6.6 belongs to clustering "+ clusters.predict(Vectors.dense("3.2 3.3 6.6".split(' ').map(_.toDouble))))
  24. 26     println("Vector 8.1 9.2 9.3 belongs to clustering "+ clusters.predict(Vectors.dense("8.1 9.2 9.3".split(' ').map(_.toDouble))))
  25. 27     println("Vector 6.2 6.5 7.3 belongs to clustering "+ clusters.predict(Vectors.dense("6.2 6.5 7.3".split(' ').map(_.toDouble))))
  26. 28 
  27. 29     println("-------Predict the non-existent line in the analyzed data file: ----------------")
  28. 30     println("Vector 1.1 2.2 3.9  belongs to clustering "+ clusters.predict(Vectors.dense("1.1 2.2 3.9".split(' ').map(_.toDouble))))
  29. 31     println("Vector 5.5 7.5 8.8  belongs to clustering "+ clusters.predict(Vectors.dense("5.5 7.5 8.8".split(' ').map(_.toDouble))))
  30. 32 
  31. 33     println("-------Evaluate clustering by computing Within Set Sum of Squared Errors:-----")
  32. 34     val wssse = clusters.computeCost(parsedData)
  33. 35     println("Within Set Sum of Squared Errors = "+ wssse)
  34. 36     sc.stop()
  35. 37   }
  36. 38 }
复制代码


4 以Spark集群standalone方式运行
①在IDEA打成jar包(如果忘记了,参见 Spark:用Scala和Java实现WordCount ),上传到用户目录下/home/ebupt/test/kmeans.jar
②准备训练样本数据:hdfs://eb170:8020/user/ebupt/kmeansData,内容如下
[ebupt@eb170 ~]$ hadoop fs -cat ./kmeansData
  1. 1.0 2.1 3.8
  2. 5.6 7.6 8.9
  3. 3.2 3.3 6.6
  4. 8.1 9.2 9.3
  5. 6.2 6.5 7.3
复制代码



③spark-submit提交运行
[ebupt@eb174 test]$  spark -submit --master spark://eb174:7077 --name KmeansWithMLib --class com.hq.KMeansTest --executor-memory 2G --total-executor-cores 4 ~/test/kmeans.jar hdfs://eb170:8020/user/ebupt/kmeansData

输出结果摘要:
  1. 1 ------Predict the existing line in the analyzed data file: hdfs://eb170:8020/user/ebupt/kmeansData
  2. 2 Vector 1.0 2.1 3.8 belongs to clustering 0
  3. 3 Vector 5.6 7.6 8.9 belongs to clustering 1
  4. 4 Vector 3.2 3.3 6.6 belongs to clustering 0
  5. 5 Vector 8.1 9.2 9.3 belongs to clustering 1
  6. 6 Vector 6.2 6.5 7.3 belongs to clustering 1
  7. 7 -------Predict the non-existent line in the analyzed data file: ----------------
  8. 8 Vector 1.1 2.2 3.9  belongs to clustering 0
  9. 9 Vector 5.5 7.5 8.8  belongs to clustering 1
  10. 10 -------Evaluate clustering by computing Within Set Sum of Squared Errors:-----
  11. 11 Within Set Sum of Squared Errors = 16.393333333333388
复制代码


5 Spark总结
本文主要介绍了MLbase如何实现机器学习算法,简单介绍了MLBase的设计思想。
与其它 机器学习 系统Weka、mahout不同:
  • MLbase是分布式的,Weka是单机的。
  • Mlbase是自动化的,Weka和mahout都需要使用者具备机器学习技能,来选择自己想要的算法和参数来做处理。
  • MLbase提供了不同抽象程度的接口,可以扩充ML算法。

参考文献:http://www.aboutyun.com/thread-10817-1-1.html

你可能感兴趣的:(Spark)