Spark随机森林之多分类模型

Spark随机森林之多分类模型

关于随机森林

随机森林算法是机器学习、计算机视觉等领域内应用极为广泛的一个算法,它不仅可以用来做分类,也可用来做回归即预测,随机森林机由多个决策树构成,相比于单个决策树算法,它分类、预测效果更好,不容易出现过度拟合的情况。

其中,决策树相关传送门 ,这里不再详述。

官方实例

以下是官方给出的一个demo

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
// 加载数据
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 将数据随机分配为两份,一份用于训练,一份用于测试
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// 随机森林训练参数设置
//分类数
val numClasses = 2
// categoricalFeaturesInfo 为空,意味着所有的特征为连续型变量
val categoricalFeaturesInfo = Map[Int, Int]()
//树的个数
val numTrees = 3 
//特征子集采样策略,auto 表示算法自主选取
val featureSubsetStrategy = "auto" 
//纯度计算
val impurity = "gini"
//树的最大层次
val maxDepth = 4
//特征最大装箱数
val maxBins = 32
//训练随机森林分类器,trainClassifier 返回的是 RandomForestModel 对象
val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
 numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

// 测试数据评价训练好的分类器并计算错误率
val labelAndPreds = testData.map { point =>
 val prediction = model.predict(point.features)
 (point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification forest model:\n" + model.toDebugString)

// 将训练后的随机森林模型持久化
model.save(sc, "myModelPath")
//加载随机森林模型到内存
val sameModel = RandomForestModel.load(sc, "myModelPath")

在这个demo中,numClasses=2,是一个二分类问题,numTrees=3,也就是说每个森林有三棵决策树,每一特征向量经过这三棵树进行分类,最后综合来看0类和1类在三棵树预测出的标签中的占比,如果2棵树或3棵树预测为0类,则为0类,否则为1类。

这给了我们一个进行多分类的小技巧:假设某一事物有多个类型A、B、C、D类,不仅如此,其中还有一些是复合类型,比如A类和B类的复合类型,那么这时候用随机森林如何进行分类判别?

那么这时就可以定义较多数量的决策树,比如我定义numTrees=10,通过

val model: RandomForestModel=RandomForest.trainClassifier(
      trainingData,numClasses,categoricalFeaturesInfo,numTrees,
      featureSubsetStrategy,impurity, maxDepth, maxBins)

val tr: Array[DecisionTreeModel] =model.trees   

可以获取随机森林中的每棵树,即一个DecisionTreeModel数组,每个DecisionTreeModel都有predict方法,这时可以获取每棵树对某一特征向量的分类判别,对10棵树的结果进行统计分析。

举例来说,对于某一特征向量,10棵树中,有8棵判别为A类,1棵判别为B类,1棵判别为C类,则这个特征所属的载体有80%的“概率”属于A类,当然也可以设定一个阙值,超过这个阙值直接判为该类别。对于另一特征向量,5棵树判别为A类,5棵树判别为D类,这时我们就可以认为它是属于A类和D类的复合类型。

你可能感兴趣的:(大数据)