算法小白的第一次尝试---PLA(感知机算法)实现

import breeze.linalg.DenseVector
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ArrayBuffer
/**
  * @author XiaoTangBao
  * @date 2019/3/6 10:13
  * @version 1.0
  * Primal form of the perceptron learning algorithm (PLA). For a linearly separable data set
  * the primal-form perceptron converges: after a finite number of updates it finds a
  * hyperplane that classifies every sample correctly.
  *
  * This driver labels Iris-setosa as +1 and every other species as -1, trains on the four
  * numeric features, and prints the learned weights and bias.
  */
object PLA {

  /** Data file used when no path is passed as the first command-line argument. */
  private val DefaultDataPath = "G:\\mldata\\iris.txt"

  /** Cap on training passes so that non-separable data cannot make the program loop forever. */
  private val MaxEpochs = 1000

  /**
    * Entry point: loads the data, trains the perceptron and prints the result.
    *
    * @param args optional; `args(0)` overrides the data file path (defaults to `DefaultDataPath`)
    */
  def main(args: Array[String]): Unit = {
    // Silence most Spark logging.
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    val sparkSession = SparkSession.builder().master("local[4]").appName("PLA").getOrCreate()
    try {
      // Data source: https://pan.baidu.com/s/17dK9fdGHzGY1SfI-s1pt6w
      val samples = loadSamples(sparkSession, args.headOption.getOrElse(DefaultDataPath))
      // Samples are four-dimensional. Initial weights w, bias b, and learning rate eta.
      // eta is a fixed hyper-parameter; only w and b are updated by training.
      val (w, b, converged) = train(samples, DenseVector(3.0, 0.8, 1.8, 2.4), 20.0, 0.2)
      if (!converged) {
        println(s"Warning: not all samples are correctly classified after $MaxEpochs epochs; " +
          "the data may not be linearly separable.")
      }
      println("训练结束")
      println(w)
      println(b)
    } finally {
      // Release the local Spark context even if loading or training throws.
      sparkSession.stop()
    }
  }

  /**
    * Reads pipe-delimited `f1|f2|f3|f4|species` lines into labelled points.
    * Iris-setosa gets label +1 and every other species gets -1.
    *
    * @param spark active Spark session used to read the file
    * @param path  location of the data file
    * @return all samples, collected to the driver in file order
    */
  private def loadSamples(spark: SparkSession, path: String): Array[LabeledPoint] =
    spark.sparkContext
      .textFile(path)
      // Trimming drops stray whitespace and a trailing '\r' from Windows line endings.
      // Without it the species comparison below would silently fail.
      .map(_.trim)
      .filter(_.nonEmpty) // skip blank lines, e.g. a trailing newline at end of file
      .map { line =>
        val fields = line.split('|')
        val label = if (fields(4).trim == "Iris-setosa") 1.0 else -1.0
        LabeledPoint(label,
          Vectors.dense(fields(0).toDouble, fields(1).toDouble, fields(2).toDouble, fields(3).toDouble))
      }
      .collect()

  /**
    * Primal-form perceptron training.
    *
    * Each pass visits the samples in order. For a misclassified sample (x, y) it repeatedly
    * applies w <- w + eta * y * x and b <- b + eta * y until that sample is classified
    * correctly. Training stops once a whole pass leaves every sample correctly classified,
    * or after `maxEpochs` passes.
    *
    * @param samples   labelled training points; labels are expected to be +1 / -1
    * @param w0        initial weight vector
    * @param b0        initial bias
    * @param eta       learning rate
    * @param maxEpochs maximum number of passes over the data
    * @return final weights, final bias, and whether every sample ended up correctly classified
    */
  private def train(samples: Array[LabeledPoint],
                    w0: DenseVector[Double],
                    b0: Double,
                    eta: Double,
                    maxEpochs: Int = MaxEpochs): (DenseVector[Double], Double, Boolean) = {
    var w = w0
    var b = b0
    var converged = false
    var epoch = 0
    while (!converged && epoch < maxEpochs) {
      for ((xi, i) <- samples.zipWithIndex) {
        val x = DenseVector(xi.features.toArray) // invariant for this sample; build once
        // With y = +/-1, each update increases y * (w.x + b) by eta * (|x|^2 + 1) > 0,
        // so this inner loop terminates.
        while (!judge(w, b, xi)) {
          println("当前纠正:X" + (i + 1))
          w = eta * xi.label * x + w
          b = b + eta * xi.label
        }
      }
      // Stop only when every sample is correctly classified by the final hyperplane.
      converged = samples.forall(judge(w, b, _))
      epoch += 1
    }
    (w, b, converged)
  }

  /**
    * Checks whether `xi` is correctly classified by the hyperplane w.x + b = 0.
    * The point counts as correct when y * (w.x + b) > 0; a point lying exactly on the
    * hyperplane counts as misclassified.
    *
    * @param w  weight vector
    * @param b  bias
    * @param xi labelled sample
    * @return true if the sample is on the side of the hyperplane indicated by its label
    */
  def judge(w: DenseVector[Double], b: Double, xi: LabeledPoint): Boolean = {
    // The infix `dot` is an alphanumeric operator, so it binds looser than `+` and `*`.
    // The parentheses are therefore needed to compute (w.x) + b.
    val margin = xi.label * ((w dot DenseVector(xi.features.toArray)) + b)
    // Written as !(margin <= 0) rather than margin > 0 so that a NaN margin still counts
    // as correct, exactly as before. Otherwise the inner training loop would never exit.
    !(margin <= 0)
  }
}
--------------------------------------------------------------------
当前纠正:X51
当前纠正:X51
当前纠正:X51
当前纠正:X1
当前纠正:X52
当前纠正:X1
当前纠正:X58
当前纠正:X1
当前纠正:X58
当前纠正:X1
当前纠正:X58
当前纠正:X1
当前纠正:X99
当前纠正:X1
训练结束
DenseVector(-1.7200000000000006, -0.1399999999999999, -3.760000000000001, 0.40000000000000024)
19.400000000000002

你可能感兴趣的:(机器学习,scala,算法,Spark,小白的算法之路)