SparkMlib 之随机森林及其案例

文章目录

    • 什么是随机森林?
    • 随机森林的优缺点
    • 随机森林示例——鸢尾花分类

什么是随机森林?

随机森林算法是机器学习、计算机视觉等领域内应用极为广泛的一个算法,它不仅可以用来做分类,也可用来做回归即预测,随机森林机由多个决策树构成,相比于单个决策树算法,它分类、预测效果更好,不容易出现过度拟合的情况。

常应用于以下类型的场景:

  1. 预测用户贷款是否能够按时还款;
  2. 预测用户是否会购买某件商品等等

官网:分类和回归

随机森林的优缺点

优点:

  1. 可以处理高纬度的数据;

  2. 训练之前不需要特意的做特征选择;

  3. 建立很多树,预防了过拟合风险;

缺点:

  1. 计算量相对于决策树很大,性能开销很大。

  2. 可能会导致有些数据集没有训练到,但这种几率很小。

  3. 分裂的时候,偏向于选择取值较多的特征。

随机森林示例——鸢尾花分类

数据集下载:

链接:
https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l 

提取码:
lz3l

数据集介绍:

iris.scale.txtlibsvm 格式的鸢尾花数据集,共有五个字段。第一个为标签字段,后四个为特征字段。

SparkMlib 之随机森林及其案例_第1张图片

libsvm 格式参考:机器学习:libsvm数据格式

将数据集中的随机百分之70作为训练集,剩余的作为测试集。

使用 SparkSQL 的方式读取 libsvm 格式的文件会自动生成 labelfeatures 结构的数据,如下所示:

val data: DataFrame = spark.read.format("libsvm").load("iris.scale.txt")

data.show()
SparkMlib 之随机森林及其案例_第2张图片

需求实现:

import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

object Iris {

    def main(args: Array[String]): Unit = {

        val spark: SparkSession = SparkSession.builder().appName("Iris").master("local[*]").getOrCreate()

        // 加载 libsvm 格式文件的数据
        val data: DataFrame = spark.read.format("libsvm").load("C:\\Users\\Administrator\\Desktop\\iris.scale.txt")

        data.show()

        // 1.构建标签列转换对象
        val labelIndexer: StringIndexerModel = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(data)

        // 2.构建特征列转换对象,设置特征列数量
        val featureIndexer: VectorIndexerModel = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                .setMaxCategories(4)
                .fit(data)

        // 3.将随机百分之70作为训练集,其余为测试集
        val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

        // 4.创建随机森林对象,设置标签列与特征列以及决策树的个数
        val rf: RandomForestClassifier = new RandomForestClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("indexedFeatures")
                .setNumTrees(10)

        // 5.设置预测列标签
        val labelConverter: IndexToString = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labelsArray(0))

        // 6.管道组装
        val pipeline: Pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

        // 7.模型训练
        val model: PipelineModel = pipeline.fit(trainingData)

        // 8.模型预测
        val predictions: DataFrame = model.transform(testData)

        // 9.模型评估
        predictions.select("predictedLabel", "label", "features").show()

        // 10.创建错误率的计算对象
        val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("accuracy")

        // 11.计算错误率
        val accuracy: Double = evaluator.evaluate(predictions)
        println(s"Test Error = ${(1.0 - accuracy)}")

        // 12.打印随机森林模型
        val rfModel: RandomForestClassificationModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
        println(s"Learned classification forest model:\n ${rfModel.toDebugString}")

    }

}

你可能感兴趣的:(随机森林,决策树,大数据,mllib)