scala实现Kmeans算法

  好久没有写博客了,虽然并没有多少人看。kmeans的思想大家自己去查找,我就不一一叙述了。kmeans之所以不能达到全局最优,是因为他的cost函数是一个非凸的函数,找不到最低点那个位置。kmeans的初始位置很重要,本篇博客采取的就是最基本的随机生成初始中心点(我很好奇,有些人的代码就是随机生成n个点,都不带判重的),比较好的生成算法是kmeans++,保证初始点间的距离最远。这是我初学scala一个月写的代码,还没有体会到scala的精髓,望各位指导!

import scala.collection.immutable.Vector
import scala.io.Source
import scala.util.Random
import scala.collection.mutable.ArrayBuffer
import com.sun.jersey.core.spi.factory.MessageBodyFactory.DistanceComparator

object MyKmeans {
  /** Reads the clustering input: one point per line, coordinates separated
    * by single spaces.
    *
    * @param pathfile path of the text file to read
    * @return one Point per input line
    */
  def GetData(pathfile: String): Array[Point] = {
    val source = Source.fromFile(pathfile)
    try {
      val data = source.getLines().toArray
        .map { line => new Point(line.split(" ").map(_.toDouble)) }
      println("读取数据完成")
      data
    } finally {
      // fix: the original leaked the file handle; always release it
      source.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val data = GetData("/home/hadoop/xixi")
    val k = new Kmeans(data, 5, 20)
    k.run()
    k.SaveData
  }
}

import scala.collection.immutable.Vector
import scala.io.Source
import scala.util.Random
import scala.collection.mutable.ArrayBuffer
import com.sun.jersey.core.spi.factory.MessageBodyFactory.DistanceComparator
import org.apache.spark.mllib.util.Saveable
import java.io.PrintWriter

/** Plain (Lloyd's) K-means over `data`.
  *
  * @param data        points to cluster
  * @param numClusters number of clusters (k)
  * @param MaxIterations upper bound on update iterations
  * @param threshold   squared-distance movement below which a center counts
  *                    as stationary
  * @param savepath    file that SaveData writes results to
  */
class Kmeans(val data:Array[Point],val numClusters:Int,val MaxIterations:Int,
    val threshold:Double=1e-4,val savepath:String="/home/hadoop/haha") {
  // Current cluster centers.
  var CenterPoint = new Array[Point](numClusters)
  // Per data point: index of its nearest center and the distance to it,
  // refreshed on every FastSearch call.
  var Costinform = new Array[Vedist](data.length)

  /** Builds `len` zero points, each of dimension `k`. */
  def InitArrPoint(len: Int, k: Int): Array[Point] =
    Array.fill(len)(new Point(new Array[Double](k)))

  /** Debug helper: prints every point in `data`, one per line. */
  def Output(data: Array[Point]) {
    data.foreach { x => x.OutPut }
  }

  /** Picks `numClusters` DISTINCT random data points as initial centers.
    * NOTE(review): this loop was garbled in the source; reconstructed as
    * rejection sampling of unique indices — confirm against the original. */
  def InitCenterPoint() {
    val chosen = new ArrayBuffer[Int] // fix: indices are Ints, not Doubles
    val st = System.nanoTime()
    var n = 0
    while (chosen.length < numClusters) {
      val a = Random.nextInt(data.length)
      if (!chosen.exists { x => x == a }) { // reject duplicate picks
        chosen += a
        CenterPoint(n) = data(a)
        n += 1
      }
    }
    val ed = System.nanoTime()
    println("--------------------------------------\n随机中心点已经生成成功,生成时间为:"+(ed-st)+"\n随机点为:")
    Output(CenterPoint)
  }

  /** Finds the center nearest to `point`, records the result in
    * Costinform(n), and returns it. */
  def FastSearch(point: Point, n: Int): Vedist = {
    var cost = Double.MaxValue
    var k = -1
    for (i <- 0 until CenterPoint.length) {
      val m = point.Distance(CenterPoint(i))
      if (cost > m) {
        cost = m
        k = i
      }
    }
    val m = Vedist(k, cost)
    Costinform(n) = m
    m
  }

  /** Replaces every current center with the corresponding NewPoint entry. */
  def setCenterPoint(NewPoint: Array[Point]) {
    for (i <- 0 until numClusters)
      CenterPoint(i) = NewPoint(i)
  }

  /** Total cost: sum of every point's distance to its assigned center. */
  def ComputeCost: Double = Costinform.map(_.cost).sum

  /** Main loop: assign each point to its nearest center, recompute centers
    * as cluster means, and stop after MaxIterations or once no center moves
    * by more than `threshold`.
    * NOTE(review): the loop header and convergence test were garbled in the
    * source; reconstructed from surrounding variables — confirm. */
  def run() {
    InitCenterPoint()
    var k = 0
    var f = true // true while some center still moves more than threshold
    val st = System.nanoTime()
    while (k < MaxIterations && f) {
      k += 1
      val NewPoint = InitArrPoint(numClusters, data(0).px.length)
      val count = new Array[Int](numClusters)
      // Assignment step: fold each point into its nearest center's sum.
      for (i <- 0 until data.length) {
        val near = FastSearch(data(i), i)
        NewPoint(near.center_id) = NewPoint(near.center_id) + data(i)
        count(near.center_id) += 1
      }
      // Update step: average the sums; skip empty clusters so their center
      // is not divided by zero (it stays at the zero point).
      for (i <- 0 until numClusters if count(i) > 0)
        NewPoint(i) = NewPoint(i) / count(i)
      // Keep iterating only while some center moved more than threshold.
      f = CenterPoint.zip(NewPoint).map(p => p._1.Distance(p._2)).exists(_ > threshold)
      if (f) {
        setCenterPoint(NewPoint)
        println("第"+k+"次中心点")
        Output(CenterPoint)
        println("第"+k+"次花费")
        println(ComputeCost)
      }
    }
    val ed = System.nanoTime()
    println("Kmeans聚类时间为:"+(ed-st))
  }

  /** Writes centers, total cost, and each point's cluster id to `savepath`. */
  def SaveData {
    val out = new PrintWriter(savepath)
    try {
      out.println("中心点为:")
      for (i <- 0 until CenterPoint.length)
        out.println(CenterPoint(i).mkString)
      out.println("花费为:")
      out.println(ComputeCost)
      out.println("各个点属于")
      for (i <- 0 until data.length)
        out.println(data(i).mkString+"          "+Costinform(i).center_id)
    } finally {
      // fix: ensure the writer is closed even if a write fails
      out.close()
    }
  }
}

import scala.collection.mutable.ArrayBuffer 
import parquet.org.codehaus.jackson.map.ser.impl.PropertySerializerMap.Empty

// Immutable n-dimensional point backed by a Vector of coordinates.
  class Point(val px: Vector[Double]) {
    /** Convenience constructor from a mutable array (defensively copied). */
    def this(p: Array[Double]) = this(p.toVector)

    /** Debug helper: prints coordinates separated by spaces, then a newline. */
    def OutPut {
      px.foreach { x => print(x + " ") }
      println
    }

    /** Squared Euclidean norm: sum of squared coordinates. */
    def ^ : Double = px.map(x => x * x).sum

    /** Component-wise sum; assumes `that` has the same dimension. */
    def +(that: Point): Point =
      new Point(px.zip(that.px).map { case (a, b) => a + b })

    /** Dot product; assumes `that` has the same dimension. */
    def *(that: Point): Double =
      px.zip(that.px).map { case (a, b) => a * b }.sum

    /** Divides every coordinate by `n` (used to average cluster members). */
    def /(n: Int): Point = new Point(px.map(_ / n))

    /** Squared Euclidean distance between the two points.
      * fix: computed from coordinate differences instead of the original
      * a²+b²−2ab expansion, which could return small negative values through
      * floating-point cancellation. */
    def Distance(that: Point): Double =
      px.zip(that.px).map { case (a, b) => (a - b) * (a - b) }.sum

    /** Zero point of dimension `len`. */
    def init(len: Int): Point = new Point(new Array[Double](len))

    /** Space-separated coordinates, with a trailing space (kept for
      * compatibility with the original output format). */
    def mkString: String = px.map(_.toString() + " ").mkString
  }

// Plain record tying a data point to its nearest cluster center:
// `center_id` is that center's index, `cost` the distance between the two.
case class Vedist(center_id: Int, cost: Double)
scala写的非常不地道,没有发挥函数式编程的优越性

你可能感兴趣的:(spark,scala,机器学习)