k-means算法

原理

k-means算法_第1张图片
timg.jpg

(a)背景:假定在二维空间中有一些待分类的样本,需要将这些样本点分为2类。
(b)随机选择n个点(n=类别数2),作为第一轮的分类中心点
(c)计算每个待测样本点与两个中心点的距离,将其归类到较近的那个类
(d)在上一轮分类得到的样本中,分别取两个类样本的中心作为新一轮的中心点。
(e)重复c,d,直到中心点不再变化。

算法实现

基于spark-mllib

数据来源:

数据源:某批发经销商的客户在各种类别产品上的年消费数
来自UCI Machine Learning Repository
http://archive.ics.uci.edu/ml/datasets/Wholesale+customers

k-means算法_第2张图片
image.png

代码实现

package Cluster

import org.apache.spark.mllib.clustering.{KMeansModel,KMeans}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Created by dengyu on 17/6/2.
  */
object KmeansTest {
  private def isColumnNameLine(line:String):Boolean = {
    if(line!=null && line.contains("Channel")) true else false
  }

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("K-means").setMaster("local")
    val sc = new SparkContext(conf)
    //读取源数据
    val rawData = sc.textFile("file:///Users/dengyu/Downloads/a.csv")
    //将源数据分为训练数据和测试数据
    val splitData = rawData.randomSplit(Array(0.7,0.3))
    val trainData = splitData(0)
    val testData = splitData(1)

    //处理训练数据,获得训练矩阵
    val train = trainData.filter(line => !isColumnNameLine(line)).map(
      line =>
        {Vectors.dense(line.split(",").map(_.trim).filter(!"".equals(_)).map(_.toDouble))}
    )
    //设置聚类的类别数量 
    val numClusters = 8
    //设置迭代次数
    val numIterations = 39
    //设置run的次数
    val runTimes = 3
    var clusterIndex : Int = 0
    //建模
    val clusters:KMeansModel = KMeans.train(train,numClusters,numIterations,runTimes)
    println("Cluster Number:"+clusters.clusterCenters.length)
    println("Cluster Centers Information Overview:")
    //输出聚类中心点的坐标
    clusters.clusterCenters.foreach(x=>{
      print("Center Point of Cluster:"+clusterIndex+":")
      println(x)
      clusterIndex += 1
    })
    //处理测试数据,获取测试矩阵
    val test = testData.filter(line => !isColumnNameLine(line)).map(
      line =>
      {Vectors.dense(line.split(",").map(_.trim).filter(!"".equals(_)).map(_.toDouble))}
    )
    //用聚类模型预测测试集,返回所在的聚类中心
    test.collect().foreach(testLine =>
    {
      val predictedClusterIndex:Int = clusters.predict(testLine)
      println("The data "+testLine.toString+" belongs to cluster "+ predictedClusterIndex)
    })

  }

}

结果

8个类及其类别中心点:

Cluster Number:8
Cluster Centers Information Overview:
Center Point of Cluster:0:[1.0793650793650793,2.515873015873016,5078.166666666666,2918.9841269841268,3263.6666666666665,2610.4841269841268,881.6428571428571,896.015873015873]
Center Point of Cluster:1:[2.0,3.0,40204.0,46314.0,57584.5,5518.0,25436.0,4241.0]
Center Point of Cluster:2:[1.1363636363636365,2.6363636363636362,40838.5,5504.227272727273,6354.772727272727,5843.590909090909,943.909090909091,2363.909090909091]
Center Point of Cluster:3:[1.830188679245283,2.5283018867924527,3388.433962264151,9543.32075471698,15430.358490566037,1309.0566037735848,6567.11320754717,1376.1509433962265]
Center Point of Cluster:4:[1.6486486486486487,2.621621621621622,13174.270270270272,6613.594594594595,9301.72972972973,1761.3513513513515,3120.0810810810813,1895.945945945946]
Center Point of Cluster:5:[2.0,3.0,22925.0,73498.0,32114.0,987.0,20070.0,903.0]
Center Point of Cluster:6:[1.0597014925373134,2.582089552238806,18936.17910447761,2617.3283582089553,3120.7014925373132,4561.985074626866,517.7164179104477,1404.5223880597014]
Center Point of Cluster:7:[2.0,2.2666666666666666,8766.533333333333,18596.266666666666,32031.933333333334,2014.8,15845.866666666667,2325.2666666666664]

测试集预测结果:

The data [2.0,3.0,9413.0,8259.0,5126.0,666.0,1795.0,1451.0] belongs to cluster 4
The data [2.0,3.0,12126.0,3199.0,6975.0,480.0,3140.0,545.0] belongs to cluster 4
The data [1.0,3.0,5567.0,871.0,2010.0,3383.0,375.0,569.0] belongs to cluster 0
The data [1.0,3.0,31276.0,1917.0,4469.0,9408.0,2381.0,4334.0] belongs to cluster 2
The data [2.0,3.0,26373.0,36423.0,22019.0,5154.0,4337.0,16523.0] belongs to cluster 7
The data [2.0,3.0,16165.0,4230.0,7595.0,201.0,4003.0,57.0] belongs to cluster 4
The data [1.0,3.0,29729.0,4786.0,7326.0,6130.0,361.0,1083.0] belongs to cluster 2
The data [1.0,3.0,1502.0,1979.0,2262.0,425.0,483.0,395.0] belongs to cluster 0
The data [1.0,3.0,56159.0,555.0,902.0,10002.0,212.0,2916.0] belongs to cluster 2
The data [2.0,3.0,10850.0,7555.0,14961.0,188.0,6899.0,46.0] belongs to cluster 4
The data [2.0,3.0,630.0,11095.0,23998.0,787.0,9529.0,72.0] belongs to cluster 3
The data [2.0,3.0,9670.0,7027.0,10471.0,541.0,4618.0,65.0] belongs to cluster 4
The data [2.0,3.0,5417.0,9933.0,10487.0,38.0,7572.0,1282.0] belongs to cluster 3
The data [1.0,3.0,13779.0,1970.0,1648.0,596.0,227.0,436.0] belongs to cluster 6
The data [1.0,3.0,6137.0,5360.0,8040.0,129.0,3084.0,1603.0] belongs to cluster 0
The data [2.0,3.0,7823.0,6245.0,6544.0,4154.0,4074.0,964.0] belongs to cluster 0
The data [2.0,3.0,85.0,20959.0,45828.0,36.0,24231.0,1423.0] belongs to cluster 7
.....

你可能感兴趣的:(k-means算法)