回归正则化方法(Lasso、Ridge 和 ElasticNet)在高维数据以及数据集变量之间存在多重共线性的情况下表现良好。
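数学上,ElasticNet被定义为L1和L2正则化项的凸组合:
$$\alpha\left(\lambda\lVert w\rVert_1\right) + (1-\alpha)\left(\frac{\lambda}{2}\lVert w\rVert_2^2\right),\quad \alpha\in[0,1],\ \lambda\ge 0$$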
通过适当设置α,ElasticNet包含L1和L2正则化作为特殊情况。例如,如果用参数α设置为1来训练线性回归模型,则其等价于Lasso模型。另一方面,如果α被设置为0,则训练的模型简化为ridge回归模型。
RegParam:正则化强度 λ(lambda),要求 λ >= 0
ElasticNetParam:L1 与 L2 的混合系数 α(alpha),取值范围为 [0, 1]
导入包
import
org.apache.spark.sql.SparkSession
import
org.apache.spark.sql.Dataset
import
org.apache.spark.sql.Row
import
org.apache.spark.sql.DataFrame
import
org.apache.spark.sql.Column
import
org.apache.spark.sql.DataFrameReader
import
org.apache.spark.rdd.RDD
import
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import
org.apache.spark.sql.Encoder
import
org.apache.spark.sql.DataFrameStatFunctions
import
org.apache.spark.sql.functions.
_
import
org.apache.spark.ml.linalg.Vectors
import
org.apache.spark.ml.feature.VectorAssembler
import
org.apache.spark.ml.evaluation.RegressionEvaluator
import
org.apache.spark.ml.regression.LinearRegression
// 数据集各列含义:
// Population: 人口
// Income: 收入水平
// Illiteracy: 文盲率
// LifeExp: 预期寿命
// Murder: 谋杀率
// HSGrad: 高中毕业率
// Frost: 结霜天数(温度在冰点以下的平均天数)
// Area: 州面积
// Obtain the SparkSession used by the rest of the script.
// The builder chain is wrapped in parentheses so it can also be pasted
// line by line into spark-shell without being cut off after the first line.
// NOTE(review): "spark.some.config.option" looks like a placeholder setting — confirm it is intended.
val spark = (SparkSession.builder()
  .appName("Spark Linear Regression")
  .config("spark.some.config.option", "some-value")
  .getOrCreate())

// For implicit conversions like converting RDDs to DataFrames
// (the script later calls `toDF` on a local List).
import spark.implicits._
// Raw observations, one tuple per row (50 rows).
// Field order matches the column names given to `toDF` further down:
//   (Population, Income, Illiteracy, LifeExp, Murder, HSGrad, Frost, Area)
// Integer literals are widened to Double by the declared element type.
val dataList: List[(Double, Double, Double, Double, Double, Double, Double, Double)] = List(
  (3615, 3624, 2.1, 69.05, 15.1, 41.3, 20, 50708),
  (365, 6315, 1.5, 69.31, 11.3, 66.7, 152, 566432),
  (2212, 4530, 1.8, 70.55, 7.8, 58.1, 15, 113417),
  (2110, 3378, 1.9, 70.66, 10.1, 39.9, 65, 51945),
  (21198, 5114, 1.1, 71.71, 10.3, 62.6, 20, 156361),
  (2541, 4884, 0.7, 72.06, 6.8, 63.9, 166, 103766),
  (3100, 5348, 1.1, 72.48, 3.1, 56, 139, 4862),
  (579, 4809, 0.9, 70.06, 6.2, 54.6, 103, 1982),
  (8277, 4815, 1.3, 70.66, 10.7, 52.6, 11, 54090),
  (4931, 4091, 2, 68.54, 13.9, 40.6, 60, 58073),
  (868, 4963, 1.9, 73.6, 6.2, 61.9, 0, 6425),
  (813, 4119, 0.6, 71.87, 5.3, 59.5, 126, 82677),
  (11197, 5107, 0.9, 70.14, 10.3, 52.6, 127, 55748),
  (5313, 4458, 0.7, 70.88, 7.1, 52.9, 122, 36097),
  (2861, 4628, 0.5, 72.56, 2.3, 59, 140, 55941),
  (2280, 4669, 0.6, 72.58, 4.5, 59.9, 114, 81787),
  (3387, 3712, 1.6, 70.1, 10.6, 38.5, 95, 39650),
  (3806, 3545, 2.8, 68.76, 13.2, 42.2, 12, 44930),
  (1058, 3694, 0.7, 70.39, 2.7, 54.7, 161, 30920),
  (4122, 5299, 0.9, 70.22, 8.5, 52.3, 101, 9891),
  (5814, 4755, 1.1, 71.83, 3.3, 58.5, 103, 7826),
  (9111, 4751, 0.9, 70.63, 11.1, 52.8, 125, 56817),
  (3921, 4675, 0.6, 72.96, 2.3, 57.6, 160, 79289),
  (2341, 3098, 2.4, 68.09, 12.5, 41, 50, 47296),
  (4767, 4254, 0.8, 70.69, 9.3, 48.8, 108, 68995),
  (746, 4347, 0.6, 70.56, 5, 59.2, 155, 145587),
  (1544, 4508, 0.6, 72.6, 2.9, 59.3, 139, 76483),
  (590, 5149, 0.5, 69.03, 11.5, 65.2, 188, 109889),
  (812, 4281, 0.7, 71.23, 3.3, 57.6, 174, 9027),
  (7333, 5237, 1.1, 70.93, 5.2, 52.5, 115, 7521),
  (1144, 3601, 2.2, 70.32, 9.7, 55.2, 120, 121412),
  (18076, 4903, 1.4, 70.55, 10.9, 52.7, 82, 47831),
  (5441, 3875, 1.8, 69.21, 11.1, 38.5, 80, 48798),
  (637, 5087, 0.8, 72.78, 1.4, 50.3, 186, 69273),
  (10735, 4561, 0.8, 70.82, 7.4, 53.2, 124, 40975),
  (2715, 3983, 1.1, 71.42, 6.4, 51.6, 82, 68782),
  (2284, 4660, 0.6, 72.13, 4.2, 60, 44, 96184),
  (11860, 4449, 1, 70.43, 6.1, 50.2, 126, 44966),
  (931, 4558, 1.3, 71.9, 2.4, 46.4, 127, 1049),
  (2816, 3635, 2.3, 67.96, 11.6, 37.8, 65, 30225),
  (681, 4167, 0.5, 72.08, 1.7, 53.3, 172, 75955),
  (4173, 3821, 1.7, 70.11, 11, 41.8, 70, 41328),
  (12237, 4188, 2.2, 70.9, 12.2, 47.4, 35, 262134),
  (1203, 4022, 0.6, 72.9, 4.5, 67.3, 137, 82096),
  (472, 3907, 0.6, 71.64, 5.5, 57.1, 168, 9267),
  (4981, 4701, 1.4, 70.08, 9.5, 47.8, 85, 39780),
  (3559, 4864, 0.6, 71.72, 4.3, 63.5, 32, 66570),
  (1799, 3617, 1.4, 69.48, 6.7, 41.6, 100, 24070),
  (4589, 4468, 0.7, 72.48, 3, 54.5, 149, 54464),
  (376, 4566, 0.6, 70.29, 6.9, 62.9, 173, 97203)
)
// Convert the local rows into a DataFrame, naming the eight tuple fields in order.
val data = dataList.toDF(
  "Population", "Income", "Illiteracy", "LifeExp",
  "Murder", "HSGrad", "Frost", "Area"
)
建立线性回归模型
// Predictor columns fed to the VectorAssembler.
// "Murder" is deliberately absent: it is set as the label column of the regression.
val colArray = Array(
  "Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area"
)
// Assembler that packs the columns listed in `colArray` into one column named "features".
val assembler = new VectorAssembler().setInputCols(colArray).setOutputCol("features")

// Result of running the assembler over `data`; this is the frame the model is fitted on.
val vecDF: DataFrame = assembler.transform(data)
// Build the model that predicts the murder rate (column "Murder").
// The linear-regression parameters are configured in three steps.

// Step 1: a fresh estimator.
val lr1 = new LinearRegression()

// Step 2: input vector column, label column, and fit an intercept term.
val lr2 = lr1.setFeaturesCol("features").setLabelCol("Murder").setFitIntercept(true)

// Step 3: iteration cap and regularization.
// RegParam is the regularization strength (lambda >= 0); ElasticNetParam is the
// L1/L2 mixing weight (alpha in [0, 1]; 1 corresponds to Lasso, 0 to ridge).
val lr3 = lr2.setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)

// `lr` is just an alias for the fully configured estimator.
val lr = lr3
// Fit the model on the assembled training frame.
val lrModel = lr.fit(vecDF)

// Output all model parameters.
// The value is not printed explicitly; it is only echoed when run in a REPL.
lrModel.extractParamMap()

// Print the coefficients and intercept for linear regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

// Apply the fitted model back to the training rows.
val predictions = lrModel.transform(vecDF)

// Show the actual label next to the prediction rounded to one decimal.
// `show()` is side-effecting, so it is called with explicit parentheses, matching
// `trainingSummary.residuals.show()` later in the script; the paren-less form is
// deprecated auto-application in Scala 2.13.
predictions.selectExpr("Murder", "round(prediction,1) as prediction").show()
// Summarize the model over the training set and print out some metrics.
val trainingSummary = lrModel.summary

// Optimizer diagnostics: iteration count and the objective value history.
println(s"numIterations: ${trainingSummary.totalIterations}")
println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")

// Residuals of the fitted model, shown as a table.
trainingSummary.residuals.show()

// Fit-quality metrics on the training data.
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"r2: ${trainingSummary.r2}")
代码执行结果
// 输出模型全部参数
lrModel.extractParamMap()
res15: org.apache.spark.ml.param.ParamMap =
{
    linReg_2ba28140e39a-elasticNetParam: 0.8,
    linReg_2ba28140e39a-featuresCol: features,
    linReg_2ba28140e39a-fitIntercept: true,
    linReg_2ba28140e39a-labelCol: Murder,
    linReg_2ba28140e39a-maxIter: 10,
    linReg_2ba28140e39a-predictionCol: prediction,
    linReg_2ba28140e39a-regParam: 0.3,
    linReg_2ba28140e39a-solver: auto,
    linReg_2ba28140e39a-standardization: true,
    linReg_2ba28140e39a-tol: 1.0E-6
}
// Print the coefficients and intercept for linear regression
println(s
"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}"
)
Coefficients: [1.36662199778084E-4,0.0,1.1834384307116244,-1.4580829641757522,0.0,-0.010686434270049252,4.051355050528196E-6] Intercept: 109.589659881471
val
predictions
=
lrModel.transform(vecDF)
predictions: org.apache.spark.sql.DataFrame = [Population: double, Income: double ... 8 more fields]
predictions.selectExpr(
"Murder"
,
"round(prediction,1) as prediction"
).show
+------+----------+
|Murder|prediction|
+------+----------+
|
15.1
|
11.9
|
|
11.3
|
11.0
|
|
7.8
|
9.5
|
|
10.1
|
8.6
|
|
10.3
|
9.6
|
|
6.8
|
4.3
|
|
3.1
|
4.2
|
|
6.2
|
7.5
|
|
10.7
|
9.3
|
|
13.9
|
12.3
|
|
6.2
|
4.7
|
|
5.3
|
4.6
|
|
10.3
|
8.8
|
|
7.1
|
6.6
|
|
2.3
|
3.5
|
|
4.5
|