线性回归简单的代码(scala实现)

package com.sxt.scala.lr

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.sql.SQLContext

/**
 * Trains a linear regression model with stochastic gradient descent on
 * `lpsa.data` and prints the mean absolute error measured on a held-out
 * test split.
 *
 * Each input line has the form `label,f1 f2 ... fn`: the value before the
 * comma is the target, the whitespace-separated values after it are the
 * features.
 */
object MyLinearRegression {

  /** Directory holding the parquet data of a previously saved model. */
  private val SavedModelDataPath = "./mymodel/data"

  /**
   * Parses one `label,f1 f2 ... fn` line into a [[LabeledPoint]].
   *
   * Features are split on runs of whitespace so that accidental double
   * spaces or leading/trailing blanks do not produce empty tokens (which
   * would make `toDouble` throw a `NumberFormatException`).
   *
   * @param line a non-blank line of the input file
   * @return the labeled feature vector described by `line`
   */
  private def parseLine(line: String): LabeledPoint = {
    val parts = line.split(",")
    val label = parts(0).trim.toDouble
    val features = parts(1).trim.split("\\s+").map(_.toDouble)
    LabeledPoint(label, Vectors.dense(features))
  }

  /**
   * Program entry point: loads the data, trains the model, evaluates it and
   * finally shows the stored model data if a saved model is present.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("mylinear").setMaster("local")
    val sc = new SparkContext(conf)

    // Always release the SparkContext, even when a step below fails.
    try {
      // Blank lines carry no sample and would otherwise break the parser.
      val examples = sc.textFile("lpsa.data")
        .filter(_.trim.nonEmpty)
        .map(parseLine)
        .cache()

      // Split the data into a training set (80%) and a test set (20%);
      // the fixed seed keeps the split reproducible.
      val Array(training, test) = examples.randomSplit(Array(0.8, 0.2), 1)

      val lrs = new LinearRegressionWithSGD()
      lrs.setIntercept(true) // fit an intercept term
      lrs.optimizer
        .setNumIterations(100)    // maximum number of iterations
        .setMiniBatchFraction(1.0) // use every sample for each gradient step
        .setStepSize(1.0)          // gradient descent step size
        .setConvergenceTol(0.02)   // stop once the change drops below this tolerance

      val model = lrs.run(training) // train the model
      println("截距:" + model.intercept)
      println("weights:" + model.weights)

      // Evaluate on the test set: absolute difference between the predicted
      // and the true value of every sample.
      val absoluteErrors = test.map { point =>
        math.abs(model.predict(point.features) - point.label)
      }
      // A random split of a tiny data set may leave the test set empty;
      // there is no meaningful average in that case.
      if (absoluteErrors.isEmpty()) {
        println("测试集为空,无法计算平均误差")
      } else {
        println("平均误差:" + absoluteErrors.mean())
      }

      // Saving / loading the model (kept disabled as in the original;
      // note that save fails if the target directory already exists):
      //   model.save(sc, "./mymodel")
      //   val mymodel = LinearRegressionModel.load(sc, "./mymodel")

      // Only inspect the stored model data when a model has actually been
      // saved; reading a missing path would abort the whole job.
      val savedModelPath = new org.apache.hadoop.fs.Path(SavedModelDataPath)
      if (savedModelPath.getFileSystem(sc.hadoopConfiguration).exists(savedModelPath)) {
        val sqlContext = new SQLContext(sc)
        sqlContext.read.parquet(SavedModelDataPath).show()
      } else {
        println("未找到已保存的模型数据:" + SavedModelDataPath)
      }
    } finally {
      sc.stop()
    }
  }

}


lpsa.data 里面的数据格式:每行一条样本,逗号之前是 y 的值(标签),逗号之后是以空格分隔的一系列 x 的值(特征)。示例:

-0.1625189,-2.166917 -0.807993 -0.78789619 -1.02470580 -0.52294088 -0.863171185 -1.04215728 -0.8644665
0.7654678,-2.036128 -0.933954 -1.8624259 -1.02470580 -0.522940888 -0.86317118 -1.04215728 -0.8644665073
1.3480731,0.1077859 -1.4722155 0.42094981 -1.0247058 -0.522940 -0.8631711 0.34262704  -0.68718690645

你可能感兴趣的:(机器学习算法)