回归是 Spark 机器学习(ML)库提供的工具之一。
波士顿房价数据:http://t.cn/RfHTAgY
在统计学中,线性回归(Linear Regression)是一种回归分析:它利用称为线性回归方程的最小二乘函数,对一个或多个自变量与因变量之间的关系进行建模。这种函数是一个或多个称为回归系数的模型参数的线性组合。
// Local Spark session that uses every available core on this machine.
val spark = SparkSession
  .builder()
  .appName("Boston linear regression")
  .master("local[*]")
  .getOrCreate()

// Load the comma-separated Boston housing file, taking column names from its header row.
// No schema inference is requested, so every column is read as a String.
val file = spark.read
  .option("header", "true")
  .option("sep", ",")
  .csv("boston_house_prices.csv")
/**
 * Columns of the Boston housing CSV (names as they appear in the header row):
 *
 * CRIM: per-capita crime rate by town.
 *
 * ZN: proportion of residential land zoned for lots over 25,000 sq.ft.
 *
 * INDUS: proportion of non-retail business acres per town.
 *
 * CHAS: Charles River dummy variable (1 if the tract bounds the river, 0 otherwise).
 *
 * NOX: nitric oxides concentration.
 *
 * RM: average number of rooms per dwelling.
 *
 * AGE: proportion of owner-occupied units built before 1940.
 *
 * DIS: weighted distances to five Boston employment centres.
 *
 * RAD: index of accessibility to radial highways.
 *
 * TAX: full-value property-tax rate per $10,000.
 *
 * PTRATIO: pupil-teacher ratio by town.
 *
 * B: 1000 * (Bk - 0.63)^2, where Bk is the proportion of Black residents by town.
 *
 * LSTAT: percentage of lower-status population.
 *
 * MEDV: median value of owner-occupied homes, in $1000s (used below as the regression label).
 */
file.show(false)
import spark.implicits._
import org.apache.spark.sql.functions.{col, rand => randomColumn}

// CSV column -> name used downstream. The label comes first and is renamed to "price";
// the 13 features keep their CSV names in lower case (spelled out to avoid locale-dependent
// lower-casing of names such as "CRIM"/"INDUS").
val columnRenames = Seq(
  "MEDV" -> "price",
  "CRIM" -> "crim",
  "ZN" -> "zn",
  "INDUS" -> "indus",
  "CHAS" -> "chas",
  "NOX" -> "nox",
  "RM" -> "rm",
  "AGE" -> "age",
  "DIS" -> "dis",
  "RAD" -> "rad",
  "TAX" -> "tax",
  "PTRATIO" -> "ptratio",
  "B" -> "b",
  "LSTAT" -> "lstat"
)

// Shuffle the rows by sorting on a random column.
// Why not a driver-side Random captured in a closure: that object is serialized into every task,
// so each partition would replay the identical sequence. functions.rand(seed) gives each partition
// its own stream and, for a fixed seed and partitioning, is reproducible across runs.
val shuffleSeed = 42L

// Cast every (string) CSV column to Double with Spark SQL instead of parsing inside a Row closure.
// NOTE: a value that does not parse becomes null here instead of throwing; VectorAssembler's
// default handleInvalid = "error" still rejects such rows when features are assembled.
val data = file
  .select(columnRenames.map { case (csvName, name) => col(csvName).cast("double").as(name) }: _*)
  .withColumn("rand", randomColumn(shuffleSeed))
  .sort("rand")
data.show(false)
// Pack the 13 numeric feature columns into a single vector column, the input form Spark ML
// estimators expect.
val ass = new VectorAssembler()
  .setInputCols(Array("crim", "zn", "indus", "chas", "nox", "rm", "age", "dis", "rad", "tax", "ptratio", "b", "lstat"))
  .setOutputCol("features")

// Cache before splitting: each split returned by randomSplit re-evaluates its parent plan on its
// own, and that plan contains a random shuffle column. Caching makes both splits sample the same
// materialised rows.
val dataset = ass.transform(data).cache()

// 80/20 train/test split. A fixed seed makes the split, and therefore the metrics printed
// afterwards, reproducible between runs.
val splitSeed = 42L
val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2), splitSeed)
train.show()

// Linear regression on standardized features, predicting the "price" label. No regularization is
// configured here.
// NOTE(review): with the default solver ("auto"), no regularization and at most 4096 features,
// Spark uses the normal-equation solver, so maxIter likely has no effect. Confirm for your Spark version.
val lr = new LinearRegression()
  .setStandardization(true)
  .setMaxIter(100000)
  .setFeaturesCol("features")
  .setLabelCol("price")

// Fit on the training split, then append a "prediction" column to the held-out test split.
val model = lr.fit(train)
val predict = model.transform(test)
predict.show(false)
+-----+-------+----+-----+----+-----+-----+-----+-------+----+-----+-------+------+-----+--------------------+---------------------------------------------------------------------------+------------------+
|price|crim |zn |indus|chas|nox |rm |age |dis |rad |tax |ptratio|b |lstat|rand |features |prediction |
+-----+-------+----+-----+----+-----+-----+-----+-------+----+-----+-------+------+-----+--------------------+---------------------------------------------------------------------------+------------------+
|19.1 |15.5757|0.0 |18.1 |0.0 |0.58 |5.926|71.0 |2.9084 |24.0|666.0|20.2 |368.74|18.13|0.030943852217344303|[15.5757,0.0,18.1,0.0,0.58,5.926,71.0,2.9084,24.0,666.0,20.2,368.74,18.13] |17.218748400911032|
|16.8 |4.22239|0.0 |18.1 |1.0 |0.77 |5.803|89.0 |1.9047 |24.0|666.0|20.2 |353.04|14.64|0.06794662017900877 |[4.22239,0.0,18.1,1.0,0.77,5.803,89.0,1.9047,24.0,666.0,20.2,353.04,14.64] |20.16152277453722 |
|17.8 |2.33099|0.0 |19.58|0.0 |0.871|5.186|93.8 |1.5296 |5.0 |403.0|14.7 |356.99|28.32|0.07100727334863999 |[2.33099,0.0,19.58,0.0,0.871,5.186,93.8,1.5296,5.0,403.0,14.7,356.99,28.32]|8.861027802620299 |
|10.5 |22.0511|0.0 |18.1 |0.0 |0.74 |5.818|92.4 |1.8662 |24.0|666.0|20.2 |391.45|22.11|0.07510523130875679 |[22.0511,0.0,18.1,0.0,0.74,5.818,92.4,1.8662,24.0,666.0,20.2,391.45,22.11] |12.688516865259075|
|5.0 |38.3518|0.0 |18.1 |0.0 |0.693|5.453|100.0|1.4896 |24.0|666.0|20.2 |396.9 |30.59|0.08927974204437017 |[38.3518,0.0,18.1,0.0,0.693,5.453,100.0,1.4896,24.0,666.0,20.2,396.9,30.59]|6.721596100173919 |
|17.0 |1.41385|0.0 |19.58|1.0 |0.871|6.129|96.0 |1.7494 |5.0 |403.0|14.7 |321.02|15.12|0.09219308651864588 |[1.41385,0.0,19.58,1.0,0.871,6.129,96.0,1.7494,5.0,403.0,14.7,321.02,15.12]|21.223825980332098|
|24.1 |0.0795 |60.0|1.69 |0.0 |0.411|6.579|35.9 |10.7103|4.0 |411.0|18.3 |370.78|5.49 |0.12572282189123463 |[0.0795,60.0,1.69,0.0,0.411,6.579,35.9,10.7103,4.0,411.0,18.3,370.78,5.49] |20.236556022708868|
|43.8 |0.08187|0.0 |2.89 |0.0 |0.445|7.82 |36.9 |3.4952 |2.0 |276.0|18.0 |393.53|3.57 |0.13252065137243751 |[0.08187,0.0,2.89,0.0,0.445,7.82,36.9,3.4952,2.0,276.0,18.0,393.53,3.57] |34.25851592117899 |
|42.8 |0.36894|22.0|5.86 |0.0 |0.431|8.259|8.4 |8.9067 |7.0 |330.0|19.1 |396.9 |3.54 |0.15850315005891624 |[0.36894,22.0,5.86,0.0,0.431,8.259,8.4,8.9067,7.0,330.0,19.1,396.9,3.54] |27.66779959608032 |
|20.3 |0.14103|0.0 |13.92|0.0 |0.437|5.79 |58.0 |6.32 |4.0 |289.0|16.0 |396.9 |15.84|0.17069545065789116 |[0.14103,0.0,13.92,0.0,0.437,5.79,58.0,6.32,4.0,289.0,16.0,396.9,15.84] |19.003634938362655|
|20.6 |0.04527|0.0 |11.93|0.0 |0.573|6.12 |76.7 |2.2875 |1.0 |273.0|21.0 |396.9 |9.08 |0.1731772014604226 |[0.04527,0.0,11.93,0.0,0.573,6.12,76.7,2.2875,1.0,273.0,21.0,396.9,9.08] |22.579748218210472|
|22.0 |0.01096|55.0|2.25 |0.0 |0.389|6.453|31.9 |7.3073 |1.0 |300.0|15.3 |394.72|8.23 |0.1724119532213404 |[0.01096,55.0,2.25,0.0,0.389,6.453,31.9,7.3073,1.0,300.0,15.3,394.72,8.23] |27.58621573155951 |
|21.4 |0.16902|0.0 |25.65|0.0 |0.581|5.986|88.4 |1.9929 |2.0 |188.0|19.1 |385.02|14.81|0.1799155020953298 |[0.16902,0.0,25.65,0.0,0.581,5.986,88.4,1.9929,2.0,188.0,19.1,385.02,14.81]|22.246679248091347|
|44.8 |0.31533|0.0 |6.2 |0.0 |0.504|8.266|78.3 |2.8944 |8.0 |307.0|17.4 |385.05|4.14 |0.1840073770973527 |[0.31533,0.0,6.2,0.0,0.504,8.266,78.3,2.8944,8.0,307.0,17.4,385.05,4.14] |37.47600066140883 |
|22.2 |0.24103|0.0 |7.38 |0.0 |0.493|6.083|43.7 |5.4159 |5.0 |287.0|19.6 |396.9 |12.79|0.19819278048259803 |[0.24103,0.0,7.38,0.0,0.493,6.083,43.7,5.4159,5.0,287.0,19.6,396.9,12.79] |18.846125680483382|
|17.3 |0.15038|0.0 |25.65|0.0 |0.581|5.856|97.0 |1.9444 |2.0 |188.0|19.1 |370.31|25.41|0.20165114056013633 |[0.15038,0.0,25.65,0.0,0.581,5.856,97.0,1.9444,2.0,188.0,19.1,370.31,25.41]|15.637800574247049|
|21.1 |0.29916|20.0|6.96 |0.0 |0.464|5.856|42.1 |4.429 |3.0 |223.0|18.6 |388.65|13.0 |0.21196533948542517 |[0.29916,20.0,6.96,0.0,0.464,5.856,42.1,4.429,3.0,223.0,18.6,388.65,13.0] |22.612442612223248|
|37.6 |0.38214|0.0 |6.2 |0.0 |0.504|8.04 |86.5 |3.2157 |8.0 |307.0|17.4 |387.38|3.13 |0.2221818318573724 |[0.38214,0.0,6.2,0.0,0.504,8.04,86.5,3.2157,8.0,307.0,17.4,387.38,3.13] |36.95679763969078 |
|24.1 |0.03445|82.5|2.03 |0.0 |0.415|6.162|38.4 |6.27 |2.0 |348.0|14.7 |393.77|7.43 |0.22497541942266708 |[0.03445,82.5,2.03,0.0,0.415,6.162,38.4,6.27,2.0,348.0,14.7,393.77,7.43] |30.53145825542635 |
|20.3 |0.08387|0.0 |12.83|0.0 |0.437|5.874|36.6 |4.5026 |5.0 |398.0|18.7 |396.06|9.1 |0.24061582791393432 |[0.08387,0.0,12.83,0.0,0.437,5.874,36.6,4.5026,5.0,398.0,18.7,396.06,9.1] |22.545830610986005|
+-----+-------+----+-----+----+-----+-----+-----+-------+----+-----+-------+------+-----+--------------------+---------------------------------------------------------------------------+------------------+
使用 MSE、RMSE、MAE、R²(R Squared)等指标评估模型在测试集上的表现:
// Score the test-set predictions against the true "price" label with Spark's built-in regression
// metrics. The values in the comments below come from one example run.
def priceEvaluator(metricName: String): RegressionEvaluator =
  new RegressionEvaluator()
    .setLabelCol("price")
    .setPredictionCol("prediction")
    .setMetricName(metricName)

// Mean squared error (example run: 19.39618580712659).
val mse_evaluator = priceEvaluator("mse")
val mse = mse_evaluator.evaluate(predict)
println(s"mse : $mse")

// Root mean squared error (example run: 4.404110103883257).
val rmse_evaluator = priceEvaluator("rmse")
val rmse = rmse_evaluator.evaluate(predict)
println(s"rmse : $rmse")

// Mean absolute error (example run: 3.236373824089298).
val mae_evaluator = priceEvaluator("mae")
val mae = mae_evaluator.evaluate(predict)
println(s"mae : $mae")

// Coefficient of determination R^2 (example run: 0.7997829281347897).
val r2_evaluator = priceEvaluator("r2")
val r2 = r2_evaluator.evaluate(predict)
println(s"r2 : $r2")
// Print the fitted linear model: one coefficient per feature (in the assembler's input-column
// order), then the intercept.
// Example run: coefficients [-0.09906791915729575, 0.06271312643290401, -0.007588177630151601,
// 2.327008944407733, -20.95397395532347, 2.9003657846148996, 0.009083827305918606,
// -1.7045934892819223, 0.33635604026919086, -0.012432940699096735, -0.9597001686441345,
// 0.007337818686994463, -0.5930690692483809], intercept 45.81790373298227.
val summaryBanner = "------系数和截距-----------"
println(summaryBanner)
println(s"系数:${model.coefficients}")
println(s"截距:${model.intercept}")