Spark Gaussian Mixture Model

Clustering dataset download: http://download.csdn.net/detail/wguangliang/9595795

A local single-machine test program (Scala, Spark MLlib) is provided below.
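The parser in the program expects each input line to hold tab-separated doubles, with the feature columns first and the true class label (0.0 or 1.0) last. A hypothetical two-feature line would look like this (illustrative values only; the files behind the link above may have a different number of feature columns):

0.85	1.32	0.0
2.47	4.01	1.0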


import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.regression.LabeledPoint
import java.io.FileWriter

object GaussianMixtureTest {
    def main(args: Array[String]): Unit = {
        // Quiet the verbose Spark and Jetty logging for local runs.
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
        Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

        // Alternative entry point driven by command-line arguments:
        // if (args.length < 2) {
        //     println("usage: DenseGmmEM <input> <k> [maxIterations]")
        // } else {
        //     val maxIterations = if (args.length > 2) args(2).toInt else 100
        //     run(args(0), args(1).toInt, maxIterations)
        // }

        // Simple search over maxIterations, kept for reference:
        // var max = 0.0
        // var maxIter = 0
        // for (i <- List(1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100)) {
        //     val corr = run("C:\\Users\\qingjian\\Desktop\\sf.txt", 2, i)
        //     if (corr > max) {
        //         max = corr
        //         maxIter = i
        //     }
        // }
        // println(maxIter + ":" + max)
        // run("C:\\Users\\qingjian\\Desktop\\sf.txt", 2, 3)

        run("C:\\Users\\qingjian\\Desktop\\Result_data.txt", 2, 4)
    }
    /**
     * Input file path, number of clusters (default 2), maximum number of EM
     * iterations (default 100). Returns the clustering accuracy.
     */
    private def run(inputFile: String, k: Int, maxIterations: Int): Double = {
        // val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
        val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example").setMaster("local[4]")
        val ctx = new SparkContext(conf)

        // Each input line holds tab-separated doubles; the last column is the
        // true class label, the preceding columns are the feature vector.
        val dataWithLabel = ctx.textFile(inputFile).map { line =>
            val split = line.trim.split('\t').map(_.toDouble)
            LabeledPoint(split.last, Vectors.dense(split.init))
        }.cache()
        // Unlabeled feature vectors used to fit the mixture model.
        val data = dataWithLabel.map(_.features).cache()


        // Fit a k-component Gaussian mixture to the features with the EM algorithm.
        val clusters = new GaussianMixture()
            .setK(k)
            .setMaxIterations(maxIterations)
            .run(data)

        /* Print the weight, mean and covariance of each fitted component:
        for (i <- 0 until clusters.k) {
            println("weight=%f\nmu=%s\nsigma=\n%s\n" format
                (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
        }

        println("The membership value of each vector to all mixture components (first <= 100):")

        val membership = clusters.predictSoft(data)
        membership.foreach { x =>
            println(" " + x.mkString(","))
        }*/

        // Hard cluster assignment for every point.
        val prediction = clusters.predict(dataWithLabel.map(_.features))
        val predictionWithLabel = dataWithLabel.zip(prediction)
        // Print each labeled point together with its assigned cluster.
        predictionWithLabel.collect.foreach { x =>
            println("prediction " + x._1 + ":" + x._2)
        }

        val writer = new FileWriter("C:\\Users\\qingjian\\Desktop\\clustered_result.txt", false)
        // Distribution of the true labels in the source data.
        val sourceData = predictionWithLabel.map(x => (x._1.label, 1)).reduceByKey(_ + _)
        writer.append("Original data\n")
        sourceData.collect.foreach(x => writer.append(x._1 + "\t" + x._2 + "(" + x._2 / predictionWithLabel.count.toDouble + ")\n"))

        // Label of the largest cluster, taken below to be the normal-data class.
        val data1TopClusterLabel = predictionWithLabel.map(x => (x._2, 1)).reduceByKey(_ + _).sortBy(_._2, false).first

        writer.append("\nClustered Instances [cluster " + data1TopClusterLabel._1 + " represents the normal data]\n")
        val clusteredInstance = predictionWithLabel.map(x => (x._2, 1)).reduceByKey(_ + _)
        clusteredInstance.collect.foreach(x => writer.append(x._1 + "\t" + x._2 + "(" + x._2 / predictionWithLabel.count.toDouble + ")\n"))
        writer.append("\n")
        clusteredInstance.collect.foreach(x => writer.append(x._1 + "\t"))
        writer.append("<--assigned to cluster\n")
        val correctPrediction = predictionWithLabel.filter(x => x._1.label == x._2)
        val errorPrediction = predictionWithLabel.filter(x => x._1.label != x._2)
        // Confusion matrix: row i is the true class, column j the assigned cluster.
        for (i <- 0 until k) {
            for (j <- 0 until k) {
                if (i == j) {
                    writer.append(correctPrediction.filter(_._1.label.toInt == i).count + "\t")
                } else {
                    // Points of class i that were assigned to cluster j.
                    writer.append(errorPrediction.filter(x => x._1.label.toInt == i && x._2 == j).count + "\t")
                }
            }
            writer.append(i + "\n")
        }

        val err = errorPrediction.count / predictionWithLabel.count.toDouble
        // If the largest cluster did not receive id 0, the cluster ids are flipped
        // relative to the true labels, so the real error rate is 1 - err.
        if (data1TopClusterLabel._1.toInt != 0) {
            writer.append("Incorrectly clustered instances : " + errorPrediction.count + "\t" + (1 - err))
        } else {
            writer.append("Incorrectly clustered instances : " + errorPrediction.count + "\t" + err)
        }
        writer.close()
        ctx.stop()
        // Return the accuracy, applying the same label-flip correction as above.
        if (data1TopClusterLabel._1.toInt != 0) err else 1 - err
    }

}
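The accuracy bookkeeping above is hardwired to k = 2: it assumes the largest cluster should carry label 0 and inverts the error rate otherwise. Because GMM cluster ids are arbitrary, a more general way to score the clustering is to try every assignment of cluster ids to class labels and keep the best match. The following is a minimal sketch of that idea; the helper name permutationAccuracy is hypothetical and not part of the program above:

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint

// Hypothetical helper: best accuracy over all mappings from cluster ids to
// class labels. Enumerating k! permutations is only practical for small k.
def permutationAccuracy(predictionWithLabel: RDD[(LabeledPoint, Int)], k: Int): Double = {
    val total = predictionWithLabel.count.toDouble
    // Aggregate (trueLabel, clusterId) pair counts once on the cluster,
    // then score every permutation cheaply on the driver.
    val counts = predictionWithLabel
        .map { case (point, cluster) => ((point.label.toInt, cluster), 1L) }
        .reduceByKey(_ + _)
        .collect()
        .toMap
    (0 until k).permutations.map { perm =>
        // perm(c) is the class label that cluster id c is mapped to.
        (0 until k).map(c => counts.getOrElse((perm(c), c), 0L)).sum / total
    }.max
}

For the two-cluster run above this reduces to max(err, 1 - err), which is exactly what the if/else around the error rate computes by hand.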


The contents written to clustered_result.txt are as follows:

Original data
0.0	514(0.514)
1.0	486(0.486)


Clustered Instances [cluster 0 represents the normal data]
0	594(0.594)
1	406(0.406)


0	1	<--assigned to cluster
514	0	0
80	406	1
Incorrectly clustered instances : 80	0.08
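
The matrix follows Weka's cluster-evaluation layout: each row is a true class (the trailing digit) and each column a cluster. All 514 instances of class 0.0 fall into cluster 0, while 80 of the 486 instances of class 1.0 are also assigned to cluster 0, giving 80 / 1000 = 0.08 incorrectly clustered instances.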


