Spark MLlib Getting Started Notes - KMeans Clustering

See the official documentation for details on using MLlib KMeans.

def train(data: RDD[Vector], k: Int, maxIterations: Int, runs: Int, initializationMode: String, seed: Long): KMeansModel
Trains a k-means model using the given set of parameters.
data: Training points as an RDD of Vector types.
k: Number of clusters to create.
maxIterations: Maximum number of iterations allowed.
runs: This param has no effect since Spark 2.0.0.
initializationMode: The initialization algorithm. This can either be "random" or "k-means||". (default: "k-means||")
seed: Random seed for cluster initialization. Default is to generate seed based on system time.
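
To make the parameters above concrete, here is a minimal sketch that sets them through the KMeans builder methods (setK, setMaxIterations, setInitializationMode, setSeed, run) rather than the train helper. The object name TrainSketch, the helper trainWithSeed and the fixed value of 30 iterations are assumptions chosen for illustration; runs is left out because it has no effect since Spark 2.0.0.

```scala
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Hypothetical helper, not part of MLlib.
object TrainSketch {
  // Train a model with an explicit initialization mode and a fixed seed,
  // so that repeated runs start from the same initial centers.
  def trainWithSeed(data: RDD[Vector], k: Int, seed: Long): KMeansModel =
    new KMeans()
      .setK(k)                            // number of clusters
      .setMaxIterations(30)               // assumed value, same as in these notes
      .setInitializationMode("k-means||") // or "random"
      .setSeed(seed)
      .run(data)
}
```

Fixing the seed is mainly useful when comparing costs across different values of k, since otherwise part of the difference may come from the random initialization.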

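A major issue with KMeans is the choice of K. Spark MLlib provides a computeCost method on the KMeansModel class, which evaluates clustering quality by computing the sum of squared distances from every data point to its nearest center. Generally, for the same number of iterations and the same number of runs, a smaller value indicates better clustering. Because this cost tends to shrink as K grows, the smallest value alone does not settle the choice: in practice the interpretability of the clustering result, or in other words experience, also has to be taken into account.

Test dataset: the Wholesale customers dataset, 440 rows in total, recording the annual spending of a wholesale distributor's customers across product categories.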
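
The quantity that computeCost reports can be sketched in plain Scala without Spark. This is only an illustration of the definition (sum of squared distances to the nearest center), not MLlib's implementation; the object CostSketch and its functions are hypothetical names.

```scala
object CostSketch {
  // Squared Euclidean distance between two points of equal dimension.
  def squaredDistance(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  // For every point take the squared distance to its nearest center,
  // then add those values up.
  def cost(points: Seq[Array[Double]], centers: Seq[Array[Double]]): Double =
    points.map(p => centers.map(c => squaredDistance(p, c)).min).sum

  def main(args: Array[String]): Unit = {
    val points = Seq(Array(0.0, 0.0), Array(0.0, 2.0), Array(10.0, 10.0))
    val centers = Seq(Array(0.0, 1.0), Array(10.0, 10.0))
    // Nearest-center squared distances are 1, 1 and 0.
    println(cost(points, centers)) // prints 2.0
  }
}
```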



package cluster

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}


object kmeans {

  // Keep only non-header rows that have exactly 8 comma-separated fields.
  def isValid(line: String): Boolean =
    line != null && !line.contains("Channel") && line.split(",").length == 8

  // Convert one CSV row into a dense vector of its 8 numeric fields.
  def parseLine(line: String): Vector =
    Vectors.dense(line.split(",").map(_.toDouble))

  def main(args: Array[String]): Unit = {
    // args(0): Spark master URL, args(1): path to the CSV file
    val conf = new SparkConf().setMaster(args(0)).setAppName("kmeans")
    val sc = new SparkContext(conf)
    val data = sc.textFile(args(1)).filter(isValid(_)).map(parseLine(_))

    // 70% of the rows for training, 30% held out
    val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)
    val trainData = splits(0)
    val testData = splits(1)

    // Train one model per candidate k and print its cost on the training data.
    val ks: Array[Int] = Array(3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20)
    ks.foreach(cluster => {
      // The runs argument is omitted: it has no effect since Spark 2.0.0.
      val model: KMeansModel = KMeans.train(trainData, cluster, 30)
      val ssd = model.computeCost(trainData)
      println("sum of squared distances of points to their nearest center when k=" + cluster + " -> "+ ssd)
    })

    sc.stop()
  }
}



sum of squared distances of points to their nearest center when k=3 -> 5.546127099037659E10
sum of squared distances of points to their nearest center when k=4 -> 4.469670002994815E10
sum of squared distances of points to their nearest center when k=5 -> 4.33608510543818E10
sum of squared distances of points to their nearest center when k=6 -> 3.2295351497875282E10
sum of squared distances of points to their nearest center when k=7 -> 3.0094488311840385E10
sum of squared distances of points to their nearest center when k=8 -> 2.9436705297706715E10
sum of squared distances of points to their nearest center when k=9 -> 2.5272091701191124E10
sum of squared distances of points to their nearest center when k=10 -> 2.5365820093186584E10
sum of squared distances of points to their nearest center when k=11 -> 1.8937065089341183E10
sum of squared distances of points to their nearest center when k=12 -> 1.748844774446142E10
sum of squared distances of points to their nearest center when k=13 -> 1.5781295960364405E10
sum of squared distances of points to their nearest center when k=14 -> 1.560378405779067E10
sum of squared distances of points to their nearest center when k=15 -> 1.4255237101885075E10
sum of squared distances of points to their nearest center when k=16 -> 1.425421947412815E10
sum of squared distances of points to their nearest center when k=17 -> 1.2284983636269205E10
sum of squared distances of points to their nearest center when k=18 -> 1.096353773020586E10
sum of squared distances of points to their nearest center when k=19 -> 1.0327953117434975E10
sum of squared distances of points to their nearest center when k=20 -> 1.013344408378622E10
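
One way to read the numbers above is to look at how much the cost drops from one k to the next and to prefer a k after which the drops become small (the usual "elbow" heuristic, which is an addition here and not something the run itself computes). The sketch below is plain Scala; the object ElbowSketch and its members are hypothetical names, and the costs are copied from the output above. Since each model was trained once with its own random initialization, the drops are noisy and should be read as a rough guide only.

```scala
object ElbowSketch {
  // (k, cost) pairs copied from the printed output above.
  val costs: Seq[(Int, Double)] = Seq(
    3 -> 5.546127099037659E10, 4 -> 4.469670002994815E10,
    5 -> 4.33608510543818E10, 6 -> 3.2295351497875282E10,
    7 -> 3.0094488311840385E10, 8 -> 2.9436705297706715E10,
    9 -> 2.5272091701191124E10, 10 -> 2.5365820093186584E10,
    11 -> 1.8937065089341183E10, 12 -> 1.748844774446142E10,
    13 -> 1.5781295960364405E10, 14 -> 1.560378405779067E10,
    15 -> 1.4255237101885075E10, 16 -> 1.425421947412815E10,
    17 -> 1.2284983636269205E10, 18 -> 1.096353773020586E10,
    19 -> 1.0327953117434975E10, 20 -> 1.013344408378622E10)

  // For each k after the first, the fraction by which the cost fell
  // relative to the previous k (negative if the cost went up).
  def relativeDrops(cs: Seq[(Int, Double)]): Seq[(Int, Double)] =
    cs.zip(cs.tail).map { case ((_, prev), (k, cur)) => (k, (prev - cur) / prev) }

  def main(args: Array[String]): Unit =
    relativeDrops(costs).foreach { case (k, drop) =>
      println(f"k=$k%2d  relative drop=${drop * 100}%6.2f%%")
    }
}
```

In this run the cost at k=10 is slightly higher than at k=9, so the drop for k=10 comes out negative, which illustrates how much a single random initialization can influence the comparison.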


package cluster

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}


object kmeans {

  // Keep only non-header rows that have exactly 8 comma-separated fields.
  def isValid(line: String): Boolean =
    line != null && !line.contains("Channel") && line.split(",").length == 8

  // Convert one CSV row into a dense vector of its 8 numeric fields.
  def parseLine(line: String): Vector =
    Vectors.dense(line.split(",").map(_.toDouble))

  def main(args: Array[String]): Unit = {
    // args(0): Spark master URL, args(1): path to the CSV file
    val conf = new SparkConf().setMaster(args(0)).setAppName("kmeans")
    val sc = new SparkContext(conf)
    val data = sc.textFile(args(1)).filter(isValid(_)).map(parseLine(_))

    // 70% of the rows for training, 30% held out for prediction
    val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)
    val trainData = splits(0)
    val testData = splits(1)

    //val ks: Array[Int] = Array(3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20)
    //ks.foreach(cluster => {
    //  val model: KMeansModel = KMeans.train(trainData, cluster, 30)
    //  val ssd = model.computeCost(trainData)
    //  println("sum of squared distances of points to their nearest center when k=" + cluster + " -> "+ ssd)
    //})

    // Train a single model with k = 8 (runs is omitted: no effect since Spark 2.0.0).
    var clusterIndex: Int = 0
    val model: KMeansModel = KMeans.train(trainData, 8, 30)
    println("Cluster Number:" + model.clusterCenters.length)
    println("Cluster Centers Information Overview:")

    // Print one heading per cluster center.
    model.clusterCenters.foreach(x => {
      println("Center Point of Cluster " + clusterIndex + ":")
      clusterIndex += 1
    })

    // Assign every held-out row to its nearest cluster center.
    testData.collect().foreach(testDataLine => {
      val predictedClusterIndex: Int = model.predict(testDataLine)
      println("The data " + testDataLine.toString + " belongs to cluster " +
        predictedClusterIndex)
    })

    sc.stop()
  }
}



Cluster Number:8
Cluster Centers Information Overview:
Center Point of Cluster 0:
Center Point of Cluster 1:
Center Point of Cluster 2:
Center Point of Cluster 3:
Center Point of Cluster 4:
Center Point of Cluster 5:
Center Point of Cluster 6:
Center Point of Cluster 7:
The data [2.0,3.0,6353.0,8808.0,7684.0,2405.0,3516.0,7844.0] belongs to cluster 6
The data [1.0,3.0,13265.0,1196.0,4221.0,6404.0,507.0,1788.0] belongs to cluster 4
The data [2.0,3.0,6006.0,11093.0,18881.0,1159.0,7425.0,2098.0] belongs to cluster 7
The data [2.0,3.0,3366.0,5403.0,12974.0,4400.0,5977.0,1744.0] belongs to cluster 6
The data [2.0,3.0,13146.0,1124.0,4523.0,1420.0,549.0,497.0] belongs to cluster 4
The data [2.0,3.0,31714.0,12319.0,11757.0,287.0,3881.0,2931.0] belongs to cluster 2
The data [1.0,3.0,5567.0,871.0,2010.0,3383.0,375.0,569.0] belongs to cluster 0
The data [1.0,3.0,31276.0,1917.0,4469.0,9408.0,2381.0,4334.0] belongs to cluster 2
The data [2.0,3.0,26373.0,36423.0,22019.0,5154.0,4337.0,16523.0] belongs to cluster 3
The data [1.0,3.0,29729.0,4786.0,7326.0,6130.0,361.0,1083.0] belongs to cluster 2
The data [1.0,3.0,56159.0,555.0,902.0,10002.0,212.0,2916.0] belongs to cluster 2
The data [1.0,3.0,19176.0,3065.0,5956.0,2033.0,2575.0,2802.0] belongs to cluster 4
The data [2.0,3.0,10850.0,7555.0,14961.0,188.0,6899.0,46.0] belongs to cluster 6
The data [2.0,3.0,44466.0,54259.0,55571.0,7782.0,24171.0,6465.0] belongs to cluster 3
The data [2.0,3.0,4967.0,21412.0,28921.0,1798.0,13583.0,1163.0] belongs to cluster 7
The data [1.0,3.0,3347.0,4051.0,6996.0,239.0,1538.0,301.0] belongs to cluster 0
The data [1.0,3.0,6137.0,5360.0,8040.0,129.0,3084.0,1603.0] belongs to cluster 6
The data [2.0,3.0,7823.0,6245.0,6544.0,4154.0,4074.0,964.0] belongs to cluster 6
The data [1.0,3.0,9.0,1534.0,7417.0,175.0,3468.0,27.0] belongs to cluster 6
The data [1.0,3.0,2446.0,7260.0,3993.0,5870.0,788.0,3095.0] belongs to cluster 0
The data [2.0,3.0,19899.0,5332.0,8713.0,8132.0,764.0,648.0] belongs to cluster 4
The data [1.0,3.0,1640.0,3259.0,3655.0,868.0,1202.0,1653.0] belongs to cluster 0
The data [2.0,3.0,219.0,9540.0,14403.0,283.0,7818.0,156.0] belongs to cluster 6
The data [2.0,3.0,16117.0,46197.0,92780.0,1026.0,40827.0,2944.0] belongs to cluster 3
The data [1.0,3.0,43265.0,5025.0,8117.0,6312.0,1579.0,14351.0] belongs to cluster 2
The data [1.0,3.0,24904.0,3836.0,5330.0,3443.0,454.0,3178.0] belongs to cluster 4
The data [1.0,3.0,11405.0,596.0,1638.0,3347.0,69.0,360.0] belongs to cluster 0
The data [2.0,3.0,1420.0,10810.0,16267.0,1593.0,6766.0,1838.0] belongs to cluster 7
The data [1.0,3.0,15587.0,1014.0,3970.0,910.0,139.0,1378.0] belongs to cluster 4
The data [2.0,3.0,8797.0,10646.0,14886.0,2471.0,8969.0,1438.0] belongs to cluster 7
The data [2.0,3.0,1531.0,8397.0,6981.0,247.0,2505.0,1236.0] belongs to cluster 6
The data [1.0,3.0,18044.0,1080.0,2000.0,2555.0,118.0,1266.0] belongs to cluster 4
The data [1.0,3.0,20049.0,1891.0,2362.0,5343.0,411.0,933.0] belongs to cluster 4
The data [1.0,3.0,17160.0,1200.0,3412.0,2417.0,174.0,1136.0] belongs to cluster 4
The data [1.0,3.0,4020.0,3234.0,1498.0,2395.0,264.0,255.0] belongs to cluster 0
The data [1.0,3.0,12212.0,201.0,245.0,1991.0,25.0,860.0] belongs to cluster 0
The data [1.0,3.0,19219.0,1840.0,1658.0,8195.0,349.0,483.0] belongs to cluster 4
The data [1.0,3.0,42312.0,926.0,1510.0,1718.0,410.0,1819.0] belongs to cluster 2
The data [1.0,3.0,13537.0,4257.0,5034.0,155.0,249.0,3271.0] belongs to cluster 4
The data [1.0,3.0,17623.0,4280.0,7305.0,2279.0,960.0,2616.0] belongs to cluster 4
The data [1.0,3.0,9203.0,3373.0,2707.0,1286.0,1082.0,526.0] belongs to cluster 0
The data [1.0,3.0,16225.0,1825.0,1765.0,853.0,170.0,1067.0] belongs to cluster 4
The data [1.0,3.0,17773.0,1366.0,2474.0,3378.0,811.0,418.0] belongs to cluster 4
The data [2.0,3.0,2861.0,6570.0,9618.0,930.0,4004.0,1682.0] belongs to cluster 6
The data [1.0,3.0,15177.0,2024.0,3810.0,2665.0,232.0,610.0] belongs to cluster 4
The data [2.0,3.0,5531.0,15726.0,26870.0,2367.0,13726.0,446.0] belongs to cluster 7
The data [2.0,3.0,4822.0,6721.0,9170.0,993.0,4973.0,3637.0] belongs to cluster 6
The data [1.0,3.0,5414.0,717.0,2155.0,2399.0,69.0,750.0] belongs to cluster 0
The data [2.0,3.0,260.0,8675.0,13430.0,1116.0,7015.0,323.0] belongs to cluster 6
The data [2.0,3.0,200.0,25862.0,19816.0,651.0,8773.0,6250.0] belongs to cluster 7
The data [1.0,3.0,45640.0,6958.0,6536.0,7368.0,1532.0,230.0] belongs to cluster 2
The data [1.0,3.0,12759.0,7330.0,4533.0,1752.0,20.0,2631.0] belongs to cluster 4
The data [1.0,3.0,2438.0,8002.0,9819.0,6269.0,3459.0,3.0] belongs to cluster 6
The data [2.0,3.0,8040.0,7639.0,11687.0,2758.0,6839.0,404.0] belongs to cluster 6
The data [1.0,3.0,5509.0,1461.0,2251.0,547.0,187.0,409.0] belongs to cluster 0
The data [2.0,3.0,180.0,3485.0,20292.0,959.0,5618.0,666.0] belongs to cluster 7
The data [1.0,3.0,17023.0,5139.0,5230.0,7888.0,330.0,1755.0] belongs to cluster 4
The data [1.0,1.0,11686.0,2154.0,6824.0,3527.0,592.0,697.0] belongs to cluster 0
The data [2.0,1.0,4484.0,14399.0,24708.0,3549.0,14235.0,1681.0] belongs to cluster 7
The data [2.0,1.0,1107.0,11711.0,23596.0,955.0,9265.0,710.0] belongs to cluster 7
The data [2.0,1.0,2541.0,4737.0,6089.0,2946.0,5316.0,120.0] belongs to cluster 6
The data [2.0,1.0,12119.0,28326.0,39694.0,4736.0,19410.0,2870.0] belongs to cluster 3
The data [1.0,1.0,3317.0,6602.0,6861.0,1329.0,3961.0,1215.0] belongs to cluster 6
The data [1.0,1.0,2806.0,10765.0,15538.0,1374.0,5828.0,2388.0] belongs to cluster 6
The data [1.0,1.0,1869.0,577.0,572.0,950.0,4762.0,203.0] belongs to cluster 0
The data [1.0,1.0,8656.0,2746.0,2501.0,6845.0,694.0,980.0] belongs to cluster 0
The data [1.0,1.0,2344.0,10678.0,3828.0,1439.0,1566.0,490.0] belongs to cluster 6
The data [1.0,1.0,964.0,4984.0,3316.0,937.0,409.0,7.0] belongs to cluster 0
The data [1.0,1.0,1838.0,6380.0,2824.0,1218.0,1216.0,295.0] belongs to cluster 0
The data [1.0,1.0,7363.0,475.0,585.0,1112.0,72.0,216.0] belongs to cluster 0
The data [1.0,1.0,18226.0,659.0,2914.0,3752.0,586.0,578.0] belongs to cluster 4
The data [1.0,1.0,6202.0,7775.0,10817.0,1183.0,3143.0,1970.0] belongs to cluster 6
The data [1.0,1.0,8885.0,2428.0,1777.0,1777.0,430.0,610.0] belongs to cluster 0
The data [1.0,1.0,13569.0,346.0,489.0,2077.0,44.0,659.0] belongs to cluster 0
The data [1.0,1.0,15671.0,5279.0,2406.0,559.0,562.0,572.0] belongs to cluster 4
The data [1.0,1.0,8040.0,3795.0,2070.0,6340.0,918.0,291.0] belongs to cluster 0
The data [1.0,1.0,29526.0,7961.0,16966.0,432.0,363.0,1391.0] belongs to cluster 4
The data [1.0,1.0,11092.0,5008.0,5249.0,453.0,392.0,373.0] belongs to cluster 0
The data [1.0,1.0,53205.0,4959.0,7336.0,3012.0,967.0,818.0] belongs to cluster 2
The data [1.0,1.0,4720.0,1032.0,975.0,5500.0,197.0,56.0] belongs to cluster 0
The data [1.0,3.0,894.0,1703.0,1841.0,744.0,759.0,1153.0] belongs to cluster 0
The data [1.0,3.0,680.0,1610.0,223.0,862.0,96.0,379.0] belongs to cluster 0
The data [1.0,3.0,9061.0,829.0,683.0,16919.0,621.0,139.0] belongs to cluster 0
The data [1.0,3.0,3366.0,2884.0,2431.0,977.0,167.0,1104.0] belongs to cluster 0
The data [1.0,3.0,68951.0,4411.0,12609.0,8692.0,751.0,2406.0] belongs to cluster 1
The data [1.0,3.0,6022.0,3354.0,3261.0,2507.0,212.0,686.0] belongs to cluster 0
The data [1.0,2.0,444.0,879.0,2060.0,264.0,290.0,259.0] belongs to cluster 0
The data [2.0,2.0,2886.0,5302.0,9785.0,364.0,6236.0,555.0] belongs to cluster 6
The data [2.0,2.0,6468.0,12867.0,21570.0,1840.0,7558.0,1543.0] belongs to cluster 7
The data [1.0,2.0,6987.0,1020.0,3007.0,416.0,257.0,656.0] belongs to cluster 0
The data [1.0,2.0,9784.0,925.0,2405.0,4447.0,183.0,297.0] belongs to cluster 0
The data [1.0,2.0,10617.0,1795.0,7647.0,1483.0,857.0,1233.0] belongs to cluster 0
The data [2.0,2.0,9759.0,25071.0,17645.0,1128.0,12408.0,1625.0] belongs to cluster 7
The data [1.0,2.0,9155.0,1897.0,5167.0,2714.0,228.0,1113.0] belongs to cluster 0
The data [1.0,2.0,15881.0,713.0,3315.0,3703.0,1470.0,229.0] belongs to cluster 4
The data [2.0,3.0,381.0,4025.0,9670.0,388.0,7271.0,1371.0] belongs to cluster 6
The data [2.0,3.0,2320.0,5763.0,11238.0,767.0,5162.0,2158.0] belongs to cluster 6
The data [2.0,3.0,1689.0,6964.0,26316.0,1456.0,15469.0,37.0] belongs to cluster 7
The data [2.0,3.0,27380.0,7184.0,12311.0,2809.0,4621.0,1022.0] belongs to cluster 4
The data [1.0,3.0,3428.0,2380.0,2028.0,1341.0,1184.0,665.0] belongs to cluster 0
The data [2.0,3.0,5981.0,14641.0,20521.0,2005.0,12218.0,445.0] belongs to cluster 7
The data [1.0,3.0,3521.0,1099.0,1997.0,1796.0,173.0,995.0] belongs to cluster 0
The data [1.0,3.0,14039.0,7393.0,2548.0,6386.0,1333.0,2341.0] belongs to cluster 4
The data [1.0,3.0,190.0,727.0,2012.0,245.0,184.0,127.0] belongs to cluster 0
The data [2.0,3.0,37.0,1275.0,22272.0,137.0,6747.0,110.0] belongs to cluster 7
The data [1.0,3.0,759.0,18664.0,1660.0,6114.0,536.0,4100.0] belongs to cluster 6
The data [1.0,3.0,796.0,5878.0,2109.0,340.0,232.0,776.0] belongs to cluster 0
The data [1.0,3.0,19746.0,2872.0,2006.0,2601.0,468.0,503.0] belongs to cluster 4
The data [1.0,3.0,2121.0,1601.0,2453.0,560.0,179.0,712.0] belongs to cluster 0
The data [1.0,3.0,20105.0,1887.0,1939.0,8164.0,716.0,790.0] belongs to cluster 4
The data [1.0,3.0,3884.0,3801.0,1641.0,876.0,397.0,4829.0] belongs to cluster 0
The data [1.0,3.0,6338.0,2256.0,1668.0,1492.0,311.0,686.0] belongs to cluster 0
The data [1.0,3.0,28257.0,944.0,2146.0,3881.0,600.0,270.0] belongs to cluster 4
The data [1.0,3.0,17770.0,4591.0,1617.0,9927.0,246.0,532.0] belongs to cluster 4
The data [1.0,3.0,11635.0,922.0,1614.0,2583.0,192.0,1068.0] belongs to cluster 0
The data [1.0,3.0,20918.0,1916.0,1573.0,1960.0,231.0,961.0] belongs to cluster 4
The data [1.0,3.0,9385.0,1530.0,1422.0,3019.0,227.0,684.0] belongs to cluster 0
The data [1.0,3.0,23632.0,6730.0,3842.0,8620.0,385.0,819.0] belongs to cluster 4
The data [1.0,3.0,4446.0,906.0,1238.0,3576.0,153.0,1014.0] belongs to cluster 0
The data [1.0,3.0,25606.0,11006.0,4604.0,127.0,632.0,288.0] belongs to cluster 4
The data [1.0,3.0,18073.0,4613.0,3444.0,4324.0,914.0,715.0] belongs to cluster 4
The data [1.0,3.0,6884.0,1046.0,1167.0,2069.0,593.0,378.0] belongs to cluster 0
The data [1.0,3.0,97.0,3605.0,12400.0,98.0,2970.0,62.0] belongs to cluster 6
The data [1.0,3.0,8861.0,3783.0,2223.0,633.0,1580.0,1521.0] belongs to cluster 0
The data [2.0,3.0,16980.0,2884.0,12232.0,874.0,3213.0,249.0] belongs to cluster 4
The data [1.0,3.0,11243.0,2408.0,2593.0,15348.0,108.0,1886.0] belongs to cluster 0
The data [1.0,3.0,31012.0,16687.0,5429.0,15082.0,439.0,1163.0] belongs to cluster 2
The data [1.0,3.0,8607.0,1750.0,3580.0,47.0,84.0,2501.0] belongs to cluster 0
The data [1.0,3.0,16731.0,3922.0,7994.0,688.0,2371.0,838.0] belongs to cluster 4
The data [1.0,3.0,29703.0,12051.0,16027.0,13135.0,182.0,2204.0] belongs to cluster 2
The data [2.0,3.0,14531.0,15488.0,30243.0,437.0,14841.0,1867.0] belongs to cluster 7
The data [1.0,3.0,10290.0,1981.0,2232.0,1038.0,168.0,2125.0] belongs to cluster 0
The data [1.0,3.0,2787.0,1698.0,2510.0,65.0,477.0,52.0] belongs to cluster 0
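
A long list of per-row assignments like the one above is hard to scan, so a per-cluster count can be a useful summary. The following is a minimal sketch, assuming the same KMeansModel and RDD[Vector] types as in the listing; the object ClusterSizeSketch and the helper clusterSizes are hypothetical names. It relies on the predict overload of KMeansModel that takes an RDD and on countByValue.

```scala
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Hypothetical helper, not part of MLlib.
object ClusterSizeSketch {
  // Predict a cluster index for every row, then count rows per index.
  // Clusters that receive no rows do not appear in the result.
  def clusterSizes(model: KMeansModel, data: RDD[Vector]): Map[Int, Long] =
    model.predict(data).countByValue().toMap
}
```

Inside main, something like ClusterSizeSketch.clusterSizes(model, testData).toSeq.sortBy(_._1).foreach(println) would print one (cluster, count) pair per line.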
