文章目录
- 《Spark高级数据分析》——预测森林植被(决策树、随机森林)
- 0. 简介
- 1. 数据准备
- 2. 训练决策树模型
- 3. 预测森林植被
- 4. 利用网格搜索与交叉验证API
- 5. 随机森林模型
- 6. 完整代码
《Spark高级数据分析》——预测森林植被(决策树、随机森林)
0. 简介
- 来源: 《Spark高级数据分析》
- 原书GitHub地址: https://github.com/sryza/aas
- 内容简述:利用Spark中的决策树、随机森林算法,预测不同类型的森林植被
1. 数据准备
- 读取森林植被特征、标签数据 covtype.data
val dataDF = loadData(spark)
dataDF.show()
def loadData(spark: SparkSession): DataFrame = {
import spark.implicits._
val dataWithoutHeaderDF = spark.read
.option("inferSchema", true)
.option("header", false)
.csv("E:/Data/saa/Chapter4_covtype/covtype.data")
// 重新定义字段名
val colNames = Seq(
"Elevation", "Aspect", "Slope",
"Horizontal_Distance_To_Hydrology", "Vertical_Distance_To_Hydrology",
"Horizontal_Distance_To_Roadways",
"Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
"Horizontal_Distance_To_Fire_Points") ++
(0 until 4).map(i => s"Wilderness_Area_$i") ++
(0 until 40).map(i => s"Soil_Type_$i") ++
Seq("Cover_Type")
dataWithoutHeaderDF.toDF(colNames: _*)
.withColumn("Cover_Type", $"Cover_Type".cast("double"))
}
+---------+------+-----+--------------------------------+------------------------------+-------------------------------+-------------+--------------+-------------+----------------------------------+-----------------+-----------------+-----------------+-----------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+----------+
|Elevation|Aspect|Slope|Horizontal_Distance_To_Hydrology|Vertical_Distance_To_Hydrology|Horizontal_Distance_To_Roadways|Hillshade_9am|Hillshade_Noon|Hillshade_3pm|Horizontal_Distance_To_Fire_Points|Wilderness_Area_0|Wilderness_Area_1|Wilderness_Area_2|Wilderness_Area_3|Soil_Type_0|Soil_Type_1|Soil_Type_2|Soil_Type_3|Soil_Type_4|Soil_Type_5|Soil_Type_6|Soil_Type_7|Soil_Type_8|Soil_Type_9|Soil_Type_10|Soil_Type_11|Soil_Type_12|Soil_Type_13|Soil_Type_14|Soil_Type_15|Soil_Type_16|Soil_Type_17|Soil_Type_18|Soil_Type_19|Soil_Type_20|Soil_Type_21|Soil_Type_22|Soil_Type_23|Soil_Type_24|Soil_Type_25|Soil_Type_26|Soil_Type_27|Soil_Type_28|Soil_Type_29|Soil_Type_30|Soil_Type_31|Soil_Type_32|Soil_Type_33|Soil_Type_34|Soil_Type_35|Soil_Type_36|Soil_Type_37|Soil_Type_38|Soil_Type_39|Cover_Type|
+---------+------+-----+--------------------------------+------------------------------+-------------------------------+-------------+--------------+-------------+----------------------------------+-----------------+-----------------+-----------------+-----------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+----------+
| 2596| 51| 3| 258| 0| 510| 221| 232| 148| 6279| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2590| 56| 2| 212| -6| 390| 220| 235| 151| 6225| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2804| 139| 9| 268| 65| 3180| 234| 238| 135| 6121| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 2.0|
| 2785| 155| 18| 242| 118| 3090| 238| 238| 122| 6211| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 2.0|
| 2595| 45| 2| 153| -1| 391| 220| 234| 150| 6172| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2579| 132| 6| 300| -15| 67| 230| 237| 140| 6031| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 2.0|
| 2606| 45| 7| 270| 5| 633| 222| 225| 138| 6256| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2605| 49| 4| 234| 7| 573| 222| 230| 144| 6228| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2617| 45| 9| 240| 56| 666| 223| 221| 133| 6244| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2612| 59| 10| 247| 11| 636| 228| 219| 124| 6230| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2612| 201| 4| 180| 51| 735| 218| 243| 161| 6222| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2886| 151| 11| 371| 26| 5253| 234| 240| 136| 4051| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 2.0|
| 2742| 134| 22| 150| 69| 3215| 248| 224| 92| 6091| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 2.0|
| 2609| 214| 7| 150| 46| 771| 213| 247| 170| 6211| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2503| 157| 4| 67| 4| 674| 224| 240| 151| 5600| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2495| 51| 7| 42| 2| 752| 224| 225| 137| 5576| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2610| 259| 1| 120| -1| 607| 216| 239| 161| 6096| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2517| 72| 7| 85| 6| 595| 228| 227| 133| 5607| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2504| 0| 4| 95| 5| 691| 214| 232| 156| 5572| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
| 2503| 38| 5| 85| 10| 741| 220| 228| 144| 5555| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 1| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 5.0|
+---------+------+-----+--------------------------------+------------------------------+-------------------------------+-------------+--------------+-------------+----------------------------------+-----------------+-----------------+-----------------+-----------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+----------+
val Array(trainDF, testDF) = dataDF.randomSplit(Array(0.75, 0.25))
trainDF.persist()
testDF.persist()
val inputCols = trainDF.columns.filter(_ != "Cover_Type")
val assembler = new VectorAssembler()
.setInputCols(inputCols)
.setOutputCol("featureVector")
val assemblerTrainDF = assembler.transform(trainDF).persist()
assemblerTrainDF.select("featureVector").show(false)
val assemblerTestDF = new VectorAssembler()
.setInputCols(inputCols)
.setOutputCol("featureVector")
.transform(testDF)
+-----------------------------------------------------------------------------------------------------+
|featureVector |
+-----------------------------------------------------------------------------------------------------+
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1863.0,37.0,17.0,120.0,18.0,90.0,217.0,202.0,115.0,769.0,1.0,1.0]) |
|(54,[0,1,2,5,6,7,8,9,13,18],[1874.0,18.0,14.0,90.0,208.0,209.0,135.0,793.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1879.0,28.0,19.0,30.0,12.0,95.0,209.0,196.0,117.0,778.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1888.0,33.0,22.0,150.0,46.0,108.0,209.0,185.0,103.0,735.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,14],[1889.0,28.0,22.0,150.0,23.0,120.0,205.0,185.0,108.0,759.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1889.0,353.0,30.0,95.0,39.0,67.0,153.0,172.0,146.0,600.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1896.0,337.0,12.0,30.0,6.0,175.0,195.0,224.0,168.0,732.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1898.0,34.0,23.0,175.0,56.0,134.0,210.0,184.0,99.0,765.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1899.0,355.0,22.0,153.0,43.0,124.0,178.0,195.0,151.0,819.0,1.0,1.0])|
|(54,[0,1,2,3,4,5,6,7,8,9,13,14],[1901.0,311.0,9.0,30.0,2.0,190.0,195.0,234.0,179.0,726.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,16],[1903.0,67.0,16.0,108.0,36.0,120.0,234.0,207.0,100.0,969.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1905.0,19.0,27.0,134.0,58.0,120.0,188.0,171.0,108.0,636.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,14],[1905.0,33.0,27.0,90.0,46.0,150.0,204.0,171.0,89.0,725.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,16],[1905.0,77.0,21.0,90.0,38.0,120.0,241.0,196.0,75.0,1025.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1906.0,356.0,20.0,150.0,55.0,120.0,184.0,201.0,151.0,726.0,1.0,1.0])|
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1908.0,323.0,32.0,150.0,52.0,120.0,125.0,190.0,196.0,765.0,1.0,1.0])|
|(54,[0,1,2,3,4,5,6,7,8,9,13,15],[1916.0,24.0,25.0,212.0,74.0,175.0,197.0,177.0,105.0,789.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,18],[1916.0,320.0,24.0,190.0,60.0,162.0,151.0,210.0,195.0,832.0,1.0,1.0])|
|(54,[0,1,2,3,4,5,6,7,8,9,13,23],[1918.0,321.0,28.0,42.0,17.0,85.0,139.0,201.0,196.0,402.0,1.0,1.0]) |
|(54,[0,1,2,3,4,5,6,7,8,9,13,14],[1919.0,30.0,22.0,67.0,9.0,256.0,208.0,188.0,107.0,661.0,1.0,1.0]) |
+-----------------------------------------------------------------------------------------------------+
2. 训练决策树模型
- 构建DecisionTreeClassifier模型,开始训练
// 构建模型
val classifier = new DecisionTreeClassifier()
.setSeed(Random.nextLong())
.setLabelCol("Cover_Type")
.setFeaturesCol("featureVector")
.setPredictionCol("prediction")
// 训练模型
val model = classifier.fit(assemblerTrainDF)
// 决策模型
println(model.toDebugString)
// 不同特征的信息增益,降序
model.featureImportances
.toArray
.zip(inputCols)
.sorted.reverse
.foreach(println)
DecisionTreeClassificationModel (uid=dtc_b7ddf2a70cb5) of depth 5 with 63 nodes
If (feature 0 <= 3052.0)
If (feature 0 <= 2558.0)
If (feature 10 <= 0.0)
If (feature 0 <= 2440.0)
If (feature 3 <= 0.0)
Predict: 4.0
Else (feature 3 > 0.0)
Predict: 3.0
Else (feature 0 > 2440.0)
If (feature 17 <= 0.0)
Predict: 3.0
Else (feature 17 > 0.0)
Predict: 3.0
……
(0.7792536945752957,Elevation)
(0.03867758671456936,Horizontal_Distance_To_Hydrology)
(0.032035474824597066,Wilderness_Area_0)
(0.030258022977407074,Soil_Type_3)
(0.030002164397023114,Hillshade_Noon)
(0.027754291761557144,Soil_Type_31)
(0.023639745113770847,Soil_Type_1)
(0.010979405745834852,Wilderness_Area_2)
(0.010136754139592311,Soil_Type_28)
(0.006542158483011739,Soil_Type_22)
……
(0.0,Soil_Type_0)
(0.0,Slope)
(0.0,Hillshade_9am)
(0.0,Aspect)
3. 预测森林植被
val predictionDF = model.transform(assemblerTestDF)
predictionDF.persist()
predictionDF.select("Cover_Type", "prediction", "probability")
.show(false)
+----------+----------+-------------------------------------------------------------------------------------------------+
|Cover_Type|prediction|probability |
+----------+----------+-------------------------------------------------------------------------------------------------+
|6.0 |3.0 |[0.0,0.0,0.03027963630125236,0.6354005832904444,0.052024360953851434,0.0,0.28229541945445186,0.0]|
|6.0 |3.0 |[0.0,0.0,0.03027963630125236,0.6354005832904444,0.052024360953851434,0.0,0.28229541945445186,0.0]|
|6.0 |3.0 |[0.0,0.0,0.03027963630125236,0.6354005832904444,0.052024360953851434,0.0,0.28229541945445186,0.0]|
|6.0 |3.0 |[0.0,0.0,0.03027963630125236,0.6354005832904444,0.052024360953851434,0.0,0.28229541945445186,0.0]|
……
+----------+----------+-------------------------------------------------------------------------------------------------+
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("Cover_Type")
.setPredictionCol("prediction")
val accuracy = evaluator.setMetricName("accuracy").evaluate(predictionDF)
val f1 = evaluator.setMetricName("f1").evaluate(predictionDF)
println(s"accuracy = $accuracy, f1 = $f1")
accuracy = 0.6986190873428979, f1 = 0.6820440997673965
4. 利用网格搜索与交叉验证API
val pipeline = new Pipeline().setStages(Array(assembler, classifier))
val paramGrid = new ParamGridBuilder()
.addGrid(classifier.impurity, Seq("gini", "entropy"))
.addGrid(classifier.maxDepth, Seq(1, 20))
.addGrid(classifier.maxBins, Seq(40, 300))
.addGrid(classifier.minInfoGain, Seq(0.0, 0.05))
.build()
val multiclassEvaluator = new MulticlassClassificationEvaluator()
.setLabelCol("Cover_Type")
.setPredictionCol("prediction")
.setMetricName("accuracy")
// 构建模型
val validator = new TrainValidationSplit()
.setSeed(Random.nextLong())
.setEstimator(pipeline)
.setEvaluator(multiclassEvaluator)
.setEstimatorParamMaps(paramGrid)
.setTrainRatio(0.8)
// 训练模型
val validatorModel = validator.fit(trainDF)
validatorModel.validationMetrics
.zip(validatorModel.getEstimatorParamMaps)
.sortBy(-_._1)
.foreach { case (metric, params) =>
println("-----------------------------------------")
println(metric)
println(params)
}
val bestModel = validatorModel.bestModel
println(bestModel.asInstanceOf[PipelineModel].stages.last.extractParamMap())
{
dtc_b7ddf2a70cb5-cacheNodeIds: false,
dtc_b7ddf2a70cb5-checkpointInterval: 10,
dtc_b7ddf2a70cb5-featuresCol: featureVector,
dtc_b7ddf2a70cb5-impurity: entropy,
dtc_b7ddf2a70cb5-labelCol: Cover_Type,
dtc_b7ddf2a70cb5-maxBins: 40,
dtc_b7ddf2a70cb5-maxDepth: 20,
dtc_b7ddf2a70cb5-maxMemoryInMB: 256,
dtc_b7ddf2a70cb5-minInfoGain: 0.0,
dtc_b7ddf2a70cb5-minInstancesPerNode: 1,
dtc_b7ddf2a70cb5-predictionCol: prediction,
dtc_b7ddf2a70cb5-probabilityCol: probability,
dtc_b7ddf2a70cb5-rawPredictionCol: rawPrediction,
dtc_b7ddf2a70cb5-seed: -6398219726571299260
}
5. 随机森林模型
val classifier = new RandomForestClassifier()
.setSeed(Random.nextLong())
.setLabelCol("Cover_Type")
.setFeaturesCol("featureVector")
.setPredictionCol("prediction")
.setNumTrees(100)
6. 完整代码
import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{DataFrame, SparkSession}
import scala.util.Random
/**
* 第四章 - 决策树 - 预测森林植被
*
* @author ALion
*/
object RunRDF {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("Demo").setMaster("local[4]")
val spark = SparkSession.builder()
.config(conf)
.enableHiveSupport()
.getOrCreate()
org.apache.log4j.Logger.getRootLogger.setLevel(
org.apache.log4j.Level.toLevel("WARN")
)
import spark.implicits._
// 1.准备数据
val dataDF = loadData(spark)
dataDF.show()
// 2. 拆分数据集
val Array(trainDF, testDF) = dataDF.randomSplit(Array(0.75, 0.25))
trainDF.persist()
testDF.persist()
// 3. 预处理
val inputCols = trainDF.columns.filter(_ != "Cover_Type")
val assembler = new VectorAssembler()
.setInputCols(inputCols)
.setOutputCol("featureVector")
// 训练集
val assemblerTrainDF = assembler.transform(trainDF).persist()
assemblerTrainDF.select("featureVector").show(false)
// 测试集
val assemblerTestDF = new VectorAssembler()
.setInputCols(inputCols)
.setOutputCol("featureVector")
.transform(testDF)
// 4. 构建决策树模型
// val classifier = new DecisionTreeClassifier()
// .setSeed(Random.nextLong())
// .setLabelCol("Cover_Type")
// .setFeaturesCol("featureVector")
// .setPredictionCol("prediction")
// 使用随机森林模型替换前面的决策树,提高准确率
val classifier = new RandomForestClassifier()
.setSeed(Random.nextLong())
.setLabelCol("Cover_Type")
.setFeaturesCol("featureVector")
.setPredictionCol("prediction")
.setNumTrees(100)
// 训练数据
val model = classifier.fit(assemblerTrainDF)
println(model.toDebugString) // 打印决策模型
// 打印不同特征的信息增益
model.featureImportances
.toArray
.zip(inputCols)
.sorted.reverse
.foreach(println)
// 5. 预测植被
val predictionDF = model.transform(assemblerTestDF)
predictionDF.persist()
predictionDF.select("Cover_Type", "prediction", "probability")
.show(false)
// 评分
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("Cover_Type")
.setPredictionCol("prediction")
val accuracy = evaluator.setMetricName("accuracy").evaluate(predictionDF)
val f1 = evaluator.setMetricName("f1").evaluate(predictionDF)
println(s"accuracy = $accuracy, f1 = $f1")
// 计算混淆矩阵
// 方法1
val predictionRDD = predictionDF
.select("prediction", "Cover_Type")
.as[(Double, Double)]
.rdd
val multiclassMetrics = new MulticlassMetrics(predictionRDD)
println(multiclassMetrics.confusionMatrix)
// 方法2
val confusionMatrix = predictionDF
.groupBy("Cover_Type")
.pivot("prediction", 1 to 7)
.count()
.na.fill(0.0)
.orderBy("Cover_Type")
confusionMatrix.show()
// 6. 网格搜索+交叉验证
// 构建管道模型
val pipeline = new Pipeline().setStages(Array(assembler, classifier))
// 构建网格参数
val paramGrid = new ParamGridBuilder()
.addGrid(classifier.impurity, Seq("gini", "entropy"))
.addGrid(classifier.maxDepth, Seq(1, 20))
.addGrid(classifier.maxBins, Seq(40, 300))
.addGrid(classifier.minInfoGain, Seq(0.0, 0.05))
.build()
// 构建分类模型的评估器
val multiclassEvaluator = new MulticlassClassificationEvaluator()
.setLabelCol("Cover_Type")
.setPredictionCol("prediction")
.setMetricName("accuracy")
// 开始网格搜索+交叉验证
val validator = new TrainValidationSplit()
.setSeed(Random.nextLong())
.setEstimator(pipeline)
.setEvaluator(multiclassEvaluator)
.setEstimatorParamMaps(paramGrid)
.setTrainRatio(0.8)
val validatorModel = validator.fit(trainDF)
// 获取训练结果的最佳模型,最佳参数
val bestModel = validatorModel.bestModel
println(bestModel.asInstanceOf[PipelineModel].stages.last.extractParamMap())
// 查看所有参数组合
validatorModel.validationMetrics
.zip(validatorModel.getEstimatorParamMaps)
.sortBy(-_._1)
.foreach { case (metric, params) =>
println("-----------------------------------------")
println(metric)
println(params)
}
spark.stop()
}
/**
* 加载原始数据
* @param spark SparkSession
* @return DataFrame
*/
def loadData(spark: SparkSession): DataFrame = {
import spark.implicits._
val dataWithoutHeaderDF = spark.read
.option("inferSchema", true)
.option("header", false)
.csv("E:/Data/saa/Chapter4_covtype/covtype.data")
// 重新定义字段名
val colNames = Seq(
"Elevation", "Aspect", "Slope",
"Horizontal_Distance_To_Hydrology", "Vertical_Distance_To_Hydrology",
"Horizontal_Distance_To_Roadways",
"Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
"Horizontal_Distance_To_Fire_Points") ++
(0 until 4).map(i => s"Wilderness_Area_$i") ++
(0 until 40).map(i => s"Soil_Type_$i") ++
Seq("Cover_Type")
dataWithoutHeaderDF.toDF(colNames: _*)
.withColumn("Cover_Type", $"Cover_Type".cast("double"))
}
}