scala实现kmeans算法

算法的概念不做过多解释,Google 一下一大把。直接贴上代码,代码里有比较详细的注释。

主程序:

 1 import scala.io.Source

 2 import scala.util.Random

 3 

 4 /**

 5  * @author vincent

 6  *

 7  */

 /**
  * Local (single-JVM) k-means clustering over 2-D points read from a text file.
  *
  * Each non-blank input line must hold two whitespace-separated numbers (x then y).
  */
 object LocalKMeans {

   /** Input file used when no path is passed on the command line. */
   private val DefaultFileName = "/home/vincent/kmeans_data.txt"

   /** Number of clusters used when none is passed on the command line. */
   private val DefaultK = 3

   /** Largest per-centroid movement that still counts as converged. */
   private val Epsilon = 0.001

   /**
    * Reads the points, picks the initial centroids at random, runs [[kmeans]]
    * and prints the final centroids together with the elapsed time.
    *
    * @param args optional: `args(0)` is the input file, `args(1)` the number of
    *             clusters; missing values fall back to the built-in defaults
    */
   def main(args: Array[String]): Unit = {
     val fileName = args.headOption.getOrElse(DefaultFileName)
     val knumbers = if (args.length > 1) args(1).toInt else DefaultK
     val rand = new Random()

     val points = readPoints(fileName)
     val centroids = initialCentroids(points, knumbers, rand)

     val startTime = System.currentTimeMillis()
     println("initialize centroids:\n" + centroids.mkString("\n") + "\n")
     println("test points: \n" + points.mkString("\n") + "\n")

     val resultCentroids = kmeans(points, centroids, Epsilon)

     val runTime = System.currentTimeMillis() - startTime
     println("run Time: " + runTime + "\nFinal centroids: \n" + resultCentroids.mkString("\n"))
   }

   /** Loads every non-blank line of `fileName` as a point and always closes the file. */
   private def readPoints(fileName: String): Vector[Point] = {
     val source = Source.fromFile(fileName)
     try {
       // toVector forces the lazy line iterator before the source is closed.
       source.getLines().map(_.trim).filter(_.nonEmpty).map(parsePoint).toVector
     } finally {
       source.close()
     }
   }

   /** Parses one trimmed line of two whitespace-separated (tab or space) numbers. */
   private def parsePoint(line: String): Point = {
     val parts = line.split("\\s+").map(_.toDouble)
     require(parts.length >= 2, s"expected two numeric columns but got: '$line'")
     Point(parts(0), parts(1))
   }

   /**
    * Picks `k` distinct points as the initial centroids.
    *
    * Sampling without replacement matters: two equal centroids are the same
    * map key in [[kmeans]], so they would stay identical on every iteration
    * and the run would effectively use fewer than `k` clusters.
    */
   private def initialCentroids(points: Seq[Point], k: Int, rand: Random): Seq[Point] = {
     val candidates = points.distinct
     require(k > 0, s"number of clusters must be positive, got $k")
     require(candidates.length >= k,
       s"need at least $k distinct points but found ${candidates.length}")
     rand.shuffle(candidates).take(k)
   }

   /**
    * Core of the algorithm: repeats the assign/update step until no centroid
    * moves by more than `epsilon`.
    *
    * @param points    data set to cluster
    * @param centroids current centroids
    * @param epsilon   movement threshold; iteration stops once every centroid
    *                  moved by at most this distance
    * @return the centroids of the final iteration, in the same order as `centroids`
    */
   @scala.annotation.tailrec
   def kmeans(points: Seq[Point], centroids: Seq[Point], epsilon: Double): Seq[Point] = {
     // Partition the data set, keyed by each point's nearest centroid.
     val clusters = points.groupBy(closestCentroid(centroids, _))
     println("clusters: \n" + clusters.mkString("\n") + "\n")

     // New centroid = mean of the cluster's points; a centroid that attracted
     // no points keeps its position.
     val newCentroids = centroids.map { oldCentroid =>
       clusters.get(oldCentroid).fold(oldCentroid) { members =>
         members.reduceLeft(_ + _) / members.length
       }
     }

     // Distance each centroid moved in this iteration.
     val movement = (centroids zip newCentroids).map { case (before, after) => before distance after }
     println("Centroids changed by\n" + movement.map(d => "%.3f".format(d)).mkString("(", ", ", ")")
       + "\nto\n" + newCentroids.mkString(", ") + "\n")

     // Keep iterating while any centroid moved by more than epsilon.
     if (movement.exists(_ > epsilon)) kmeans(points, newCentroids, epsilon)
     else newCentroids
   }

   /**
    * Returns the centroid nearest to `point`; among equally distant centroids
    * the one appearing last in `centroids` wins.
    *
    * @throws IllegalArgumentException if `centroids` is empty
    */
   def closestCentroid(centroids: Seq[Point], point: Point): Point = {
     require(centroids.nonEmpty, "at least one centroid is required")
     centroids.reduceLeft((a, b) => if ((point distance a) < (point distance b)) a else b)
   }
 }

 

自定义Point类:

 1 /**

 2  * @author vincent

 3  *

 4  */

 /** Factory helpers for [[Point]]. */
 object Point {

   /** Returns a point whose coordinates are each drawn independently from [0, 50). */
   def random(): Point =
     Point(java.lang.Math.random() * 50, java.lang.Math.random() * 50)
 }

 /**
  * Immutable 2-D point with the small amount of vector arithmetic k-means needs.
  *
  * @param x horizontal coordinate
  * @param y vertical coordinate
  */
 final case class Point(x: Double, y: Double) {

   /** Component-wise sum. */
   def +(that: Point): Point = Point(this.x + that.x, this.y + that.y)

   /** Component-wise difference. */
   def -(that: Point): Point = Point(this.x - that.x, this.y - that.y)

   /** Divides both coordinates by `d`. */
   def /(d: Double): Point = Point(this.x / d, this.y / d)

   /** Euclidean length of the vector from the origin to this point. */
   def pointLength: Double = math.sqrt(x * x + y * y)

   /** Euclidean distance between this point and `that`. */
   def distance(that: Point): Double = (this - that).pointLength

   /** Renders the point as `(x, y)` with three decimals per coordinate. */
   override def toString: String = "(%.3f, %.3f)".format(x, y)
 }

测试数据集(每行两列,分别为 x、y 坐标;原始文件中两列以制表符分隔):

12.044996    36.412378

31.881257    33.677009

41.703139    46.170517

43.244406    6.991669

19.319000    27.926669

3.556824    40.935215

29.328655    33.303675

43.702858    22.305344

28.978940    28.905725

10.426760    40.311507

 

 

你可能感兴趣的:(scala)