K-means 聚类分析实现类源码

数据文件来自：http://archive.ics.uci.edu/ml/datasets/Wholesale+customers


import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
/**
 * Spark MLlib K-means clustering example.
 *
 * Trains a K-means model on tab-separated numeric records, prints the
 * resulting cluster centers, then prints the predicted cluster for every
 * record of a second (test) file.
 *
 * Usage: KMeansClustering trainingDataFilePath testDataFilePath numClusters numIterations runTimes
 */
object KMeansClustering {

  /**
   * Entry point.
   *
   * @param args five required arguments: trainingDataFilePath, testDataFilePath,
   *             numClusters, numIterations, runTimes (the last three are parsed
   *             with `toInt`). Exits with status 1 when fewer than five are given.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 5) {
      println("Usage:KMeansClustering trainingDataFilePath testDataFilePath numClusters numIterations runTimes")
      sys.exit(1)
    }

    val conf = new SparkConf().setAppName("Spark MLlib Exercise:K-Means Clustering")
    val sc = new SparkContext(conf)

    try {
      // Training data: drop the header line, parse the rest, and cache it
      // because K-means passes over the RDD once per iteration.
      val parsedTrainingData = sc.textFile(args(0))
        .filter(!isColumnNameLine(_))
        .map(parseVector)
        .cache()

      // Cluster the data into numClusters classes using KMeans
      val numClusters = args(2).toInt
      val numIterations = args(3).toInt
      val runTimes = args(4).toInt
      // NOTE(review): the `runs` argument of KMeans.train is deprecated in newer
      // Spark releases (and reportedly has no effect since 2.0) — kept here so the
      // command-line contract is unchanged; confirm against the Spark version used.
      val clusters: KMeansModel =
        KMeans.train(parsedTrainingData, numClusters, numIterations, runTimes)

      println("Cluster Number:" + clusters.clusterCenters.length)
      println("Cluster Centers Information Overview:")
      clusters.clusterCenters.zipWithIndex.foreach { case (center, clusterIndex) =>
        println("Center Point of Cluster " + clusterIndex + ":")
        println(center)
      }

      // Check which cluster each test record belongs to. The same header filter
      // as for the training data is applied so a column-name line in the test
      // file is skipped instead of failing in toDouble.
      val parsedTestData = sc.textFile(args(1))
        .filter(!isColumnNameLine(_))
        .map(parseVector)
      parsedTestData.collect().foreach { testDataLine =>
        val predictedClusterIndex: Int = clusters.predict(testDataLine)
        println("The data " + testDataLine.toString + " belongs to cluster " +
          predictedClusterIndex)
      }

      println("Spark MLlib K-means clustering test finished.")
    } finally {
      // Release the Spark context even if parsing or training throws.
      sc.stop()
    }
  }

  /**
   * Parses one line into a dense vector: fields are split on tab, trimmed,
   * empty fields dropped, and the rest converted with `toDouble`.
   * NOTE(review): the UCI download appears to be comma-separated; this assumes
   * the file was converted to tab-separated beforehand — verify.
   */
  private def parseVector(line: String): org.apache.spark.mllib.linalg.Vector =
    Vectors.dense(line.split("\t").map(_.trim).filter(_.nonEmpty).map(_.toDouble))

  /** True when the line contains "Channel" (treated as the column-name header); false for null. */
  private def isColumnNameLine(line: String): Boolean =
    line != null && line.contains("Channel")
}
 


你可能感兴趣的:(Spark,ML)