package com.sxt.scala.lr
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.sql.SQLContext
/**
 * Trains a linear regression model with stochastic gradient descent on the
 * records in `lpsa.data`, prints the fitted intercept and weights, and reports
 * the mean absolute error measured on a held-out test split.
 */
object MyLinearRegression {

  /**
   * Parses one input record of the form `label,f1 f2 f3 ...`: the value before
   * the comma is the label, the whitespace-separated values after it are the
   * features.
   *
   * @param line a non-blank line read from the input file
   * @return the labeled point described by the line
   * @throws IllegalArgumentException if the line has no comma-separated feature part
   * @throws NumberFormatException if the label or a feature is not a number
   */
  private def parseLabeledPoint(line: String): LabeledPoint = {
    val parts = line.split(",")
    require(parts.length >= 2, s"Malformed record (expected 'label,f1 f2 ...'): $line")
    val label = parts(0).trim.toDouble
    // Split on any run of whitespace so repeated spaces do not yield empty tokens.
    val features = parts(1).trim.split("\\s+").map(_.toDouble)
    LabeledPoint(label, Vectors.dense(features))
  }

  /**
   * Program entry point: loads the data, trains the model, evaluates it and
   * shows the contents of a previously saved model.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("mylinear").setMaster("local")
    val sc = new SparkContext(conf)
    // Stop the context even when a step fails, so it is never left running.
    try {
      val examples = sc.textFile("lpsa.data")
        .filter(_.trim.nonEmpty) // skip blank lines such as a trailing newline
        .map(line => parseLabeledPoint(line))
        .cache()

      // Split the data into a training set (80%) and a test set (20%) with a fixed seed.
      val Array(trainingSet, testSet) = examples.randomSplit(Array(0.8, 0.2), 1)

      val lrs = new LinearRegressionWithSGD()
      lrs.setIntercept(true) // fit an intercept term
      lrs.optimizer
        .setNumIterations(100)      // number of iterations
        .setMiniBatchFraction(1.0)  // use every sample in each iteration
        .setStepSize(1)             // gradient descent step size
        .setConvergenceTol(0.02)    // convergence tolerance

      val model = lrs.run(trainingSet) // train the model
      println("截距:" + model.intercept)
      println("weights:" + model.weights)

      // Evaluate on the test set: accumulate the absolute error sum and the
      // sample count in a single pass over the data.
      val (absErrorSum, testCount) = testSet
        .map(point => math.abs(model.predict(point.features) - point.label))
        .aggregate((0.0, 0L))(
          (acc, err) => (acc._1 + err, acc._2 + 1L),
          (left, right) => (left._1 + right._1, left._2 + right._2)
        )

      if (testCount == 0L) {
        println("Test split is empty; the mean absolute error cannot be computed.")
      } else {
        val error = absErrorSum / testCount
        println("平均误差:" + error)
      }

      // model.save(sc, "./mymodel")
      // val mymodel = LinearRegressionModel.load(sc, "./mymodel") // read and load the model

      // NOTE(review): this reads the Parquet data under ./mymodel/data, which is
      // expected to exist from an earlier run with the save call above enabled.
      val sqlContext = new SQLContext(sc)
      sqlContext.read.parquet("./mymodel/data").show()
    } finally {
      sc.stop()
    }
  }
}
/*
Data format of lpsa.data: the value before the comma is y (the label);
the space-separated values after the comma are the x values (the features).

-0.1625189,-2.166917 -0.807993 -0.78789619 -1.02470580 -0.52294088 -0.863171185 -1.04215728 -0.8644665
0.7654678,-2.036128 -0.933954 -1.8624259 -1.02470580 -0.522940888 -0.86317118 -1.04215728 -0.8644665073
1.3480731,0.1077859 -1.4722155 0.42094981 -1.0247058 -0.522940 -0.8631711 0.34262704 -0.68718690645
*/