Using MLPC (Multilayer Perceptron Classifier), Spark's Neural-Network-Based Classifier

This article was first published on my personal blog, QIMING.INFO. Please include a link and attribution when reposting.

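MLPC (Multilayer Perceptron Classifier) is a classifier based on a feedforward artificial neural network (ANN). It is currently the only neural-network-related algorithm that Spark supports, and it lives in org.apache.spark.ml (not mllib). This article uses code to walk through a small example of running MLPC on Spark.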

Algorithm overview

A multilayer perceptron is a feedforward neural network model with multiple layers.

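In a feedforward neural network, each layer, starting from the input layer, receives input only from the layer before it and passes its result to the layer after it. Nothing is fed back to an earlier layer, so the whole computation can be represented as a directed acyclic graph. A network of this type is made up of three kinds of layers: an input layer, one or more hidden layers, and an output layer, as shown in the figure: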

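MLPC is trained with the BP (back propagation) algorithm. The goal of BP is to adjust the network's connection weights so that the adjusted network produces the expected output for any given input. The "back propagation" in the name refers to how the algorithm passes the error backward through the network layer by layer during training, correcting the connection weights between neurons one by one until the output computed for an input falls within the desired error.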
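To make the weight-adjustment idea concrete, here is a minimal sketch of repeated gradient steps on a single sigmoid neuron with squared error. It is an illustration only: the names (BackpropSketch, step, learningRate) are invented for this example, the input and target values are arbitrary, and it uses plain gradient descent, which is not necessarily the optimizer Spark applies internally.

```scala
object BackpropSketch {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Output of one neuron: sigmoid(w . x + b)
  def predict(w: Array[Double], b: Double, x: Array[Double]): Double =
    sigmoid(w.zip(x).map { case (wi, xi) => wi * xi }.sum + b)

  // Squared error for a single example
  def error(y: Double, target: Double): Double =
    0.5 * (y - target) * (y - target)

  // One gradient-descent step; returns the updated (weights, bias)
  def step(w: Array[Double], b: Double, x: Array[Double], target: Double,
           learningRate: Double): (Array[Double], Double) = {
    val y = predict(w, b, x)
    // Error signal passed back through the sigmoid: dE/dz = (y - t) * y * (1 - y)
    val delta = (y - target) * y * (1.0 - y)
    val newW = w.zip(x).map { case (wi, xi) => wi - learningRate * delta * xi }
    (newW, b - learningRate * delta)
  }

  def main(args: Array[String]): Unit = {
    val x = Array(1.0, 0.5) // arbitrary example input
    val target = 1.0        // arbitrary desired output
    var w = Array(0.0, 0.0)
    var b = 0.0
    println(s"error before training: ${error(predict(w, b, x), target)}")
    for (_ <- 1 to 100) {
      val (nw, nb) = step(w, b, x, target, learningRate = 0.5)
      w = nw
      b = nb
    }
    println(s"error after training:  ${error(predict(w, b, x), target)}")
  }
}
```

In a real multilayer network the same error signal is propagated through every layer in turn, which is where the "backward" in the name comes from.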

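In Spark's multilayer perceptron, the hidden-layer neurons use the sigmoid function as their activation function, and the output layer uses the softmax function.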
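The following sketch shows what a forward pass with a sigmoid hidden layer and a softmax output layer looks like in plain Scala. It is not Spark's implementation: the object name, the helper functions, and the weight values are all made up for illustration, and the network has only one hidden layer to keep it short.

```scala
object ForwardPassSketch {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Softmax turns raw scores into probabilities that sum to 1
  def softmax(zs: Array[Double]): Array[Double] = {
    val m = zs.max // subtract the max for numerical stability
    val exps = zs.map(z => math.exp(z - m))
    val total = exps.sum
    exps.map(_ / total)
  }

  // Affine step of one layer: w(j) holds the incoming weights of neuron j
  def affine(w: Array[Array[Double]], b: Array[Double], x: Array[Double]): Array[Double] =
    w.zip(b).map { case (row, bias) =>
      row.zip(x).map { case (wi, xi) => wi * xi }.sum + bias
    }

  // input -> sigmoid hidden layer -> softmax output layer
  def forward(x: Array[Double],
              hiddenW: Array[Array[Double]], hiddenB: Array[Double],
              outW: Array[Array[Double]], outB: Array[Double]): Array[Double] = {
    val hidden = affine(hiddenW, hiddenB, x).map(sigmoid)
    softmax(affine(outW, outB, hidden))
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical weights for a 2-input, 2-hidden-neuron, 3-class network
    val hiddenW = Array(Array(0.1, -0.2), Array(0.4, 0.3))
    val hiddenB = Array(0.0, 0.1)
    val outW = Array(Array(0.5, -0.5), Array(-0.3, 0.8), Array(0.2, 0.2))
    val outB = Array(0.0, 0.0, 0.0)
    val probs = forward(Array(0.5, -0.5), hiddenW, hiddenB, outW, outB)
    println(probs.mkString("class probabilities: [", ", ", "]"))
  }
}
```

The predicted class is the index of the largest probability in the output array.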

Several important tunable parameters of MLPC:

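  • featuresCol: the name of the features column in the input DataFrame.
  • labelCol: the name of the label column in the input DataFrame.
  • layers: an array of integers. The first element must equal the dimension of the feature vector, and the last element must equal the number of labels (classes) in the training data; for a 2-class problem, write 2. The number of elements in between is the number of hidden layers, and each value is the number of neurons in that layer. For example, val layers = Array(5, 6, 5, 2).
  • maxIter: the maximum number of iterations of the optimization algorithm. The default is 100.
  • predictionCol: the name of the column holding the prediction results.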
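As a rough aid for reading a layers array, the sketch below splits it into input size, hidden layers, and output size, and counts the trainable parameters under the usual assumption that every neuron has one weight per incoming connection plus one bias. The object and function names are hypothetical, not part of Spark's API.

```scala
object LayersSketch {
  // Sizes of the hidden layers: everything between the first and last element
  def hiddenLayers(layers: Array[Int]): Array[Int] = {
    require(layers.length >= 2, "need at least an input and an output layer")
    layers.slice(1, layers.length - 1)
  }

  // For each pair of adjacent layers: (inputs + 1 bias) * outputs
  def parameterCount(layers: Array[Int]): Int = {
    require(layers.length >= 2, "need at least an input and an output layer")
    layers.sliding(2).map { case Array(in, out) => (in + 1) * out }.sum
  }

  def main(args: Array[String]): Unit = {
    val layers = Array(5, 6, 5, 2)
    println(s"input size:    ${layers.head}")
    println(s"hidden layers: ${hiddenLayers(layers).mkString(", ")}")
    println(s"output size:   ${layers.last}")
    println(s"parameters:    ${parameterCount(layers)}")
  }
}
```

Counting parameters this way gives a quick feel for how fast the model grows when hidden layers are added or widened.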

Steps to run

Data description

MLPC is strict about its data source, which must be one of the following two:

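  • DataFrame
    When a DataFrame is the data source, its label column and features column must be specified;
  • LIBSVM-format text file
    Each line has the form: label featureID:value featureID:value ...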

This example uses a LIBSVM-format text file. The data looks like this:

[xuqm@cu01 ML_Data]$ cat input/sample_multiclass_classification_data.txt 
1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333 
1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667 
1 1:-0.722222 2:-0.166667 3:-0.864407 4:-0.833333 
1 1:-0.722222 2:0.166667 3:-0.694915 4:-0.916667 
0 1:0.166667 2:-0.416667 3:0.457627 4:0.5 
……
……
……
2 1:-0.388889 2:-0.166667 3:0.186441 4:0.166667 
0 1:-0.222222 2:-0.583333 3:0.355932 4:0.583333 
1 1:-0.611111 2:-0.166667 3:-0.79661 4:-0.916667 
1 1:-0.944444 2:-0.25 3:-0.864407 4:-0.916667 
1 1:-0.388889 2:0.166667 3:-0.830508 4:-0.75 
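To show how one such line maps to a label and a sparse feature vector, here is a small hand-written parser. It is only a sketch for illustration (in the actual example Spark's libsvm reader does this work), and the names LibsvmLineSketch, Parsed, and parseLine are made up. Feature IDs in the file are 1-based, so the sketch shifts them to 0-based indices; any ID missing from a line stands for a zero value.

```scala
object LibsvmLineSketch {
  // One parsed line: the label plus parallel arrays of 0-based indices and values
  final case class Parsed(label: Double, indices: Array[Int], values: Array[Double])

  def parseLine(line: String): Parsed = {
    // Lines may carry trailing whitespace, so trim before splitting
    val tokens = line.trim.split("\\s+")
    val label = tokens.head.toDouble
    val pairs = tokens.tail.map { token =>
      val Array(id, value) = token.split(":")
      (id.toInt - 1, value.toDouble) // shift the 1-based ID to a 0-based index
    }
    Parsed(label, pairs.map(_._1), pairs.map(_._2))
  }

  def main(args: Array[String]): Unit = {
    val parsed = parseLine("1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333 ")
    println(s"label   = ${parsed.label}")
    println(s"indices = ${parsed.indices.mkString("[", ",", "]")}")
    println(s"values  = ${parsed.values.mkString("[", ",", "]")}")
  }
}
```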

Code and explanation


import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

import org.apache.spark.sql.SparkSession

object MLPCTest {

  def main(args: Array[String]): Unit = {
    // Build the SparkSession
    val spark = SparkSession.builder.appName("MLPCTest").getOrCreate()

    // Load the data stored in LIBSVM format
    val data = spark.read.format("libsvm").load("file:///home/xuqm/ML_Data/input/sample_multiclass_classification_data.txt")

    // Split into a training set and a test set
    val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L)
    val train = splits(0)
    val test = splits(1)

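    // Specify the layers of the neural network:
    // input layer with 4 nodes (the 4 features); two hidden layers with 5 and 4 nodes; output layer with 3 nodes (the 3 classes)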
    val layers = Array[Int](4, 5, 4, 3)

    // Create the MLPC trainer and set its parameters
    val trainer = new MultilayerPerceptronClassifier().
      setLayers(layers).
      setBlockSize(128).
      setSeed(1234L).
      setMaxIter(100)

    // Train the model
    val model = trainer.fit(train)

    // Use the trained model to predict on the test set
    val result = model.transform(test)
    val predictionAndLabels = result.select("prediction", "label")

    // Compute the accuracy on the test set and print it
    val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
    println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels))

    // Print the results
    result.show(60,false)
  }
}

Results

// Compute the accuracy on the test set and print it
Test set accuracy = 0.9019607843137255

// Print the results
result.show(60,false)
+-----+---------------------------------------------------------+----------+
|label|features                                                 |prediction|
+-----+---------------------------------------------------------+----------+
|0.0  |(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333])    |2.0       |
|0.0  |(4,[0,1,2,3],[-0.277778,-0.333333,0.322034,0.583333])    |0.0       |
|0.0  |(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333])    |0.0       |
|0.0  |(4,[0,1,2,3],[-0.0555556,-0.833333,0.355932,0.166667])   |2.0       |
|0.0  |(4,[0,1,2,3],[-0.0555556,-0.166667,0.288136,0.416667])   |2.0       |
|0.0  |(4,[0,1,2,3],[-1.32455E-7,-0.166667,0.322034,0.416667])  |2.0       |
|0.0  |(4,[0,1,2,3],[0.111111,-0.583333,0.355932,0.5])          |0.0       |
|0.0  |(4,[0,1,2,3],[0.222222,-0.166667,0.627119,0.75])         |0.0       |
|0.0  |(4,[0,1,2,3],[0.333333,-0.583333,0.627119,0.416667])     |0.0       |
|0.0  |(4,[0,1,2,3],[0.333333,-0.166667,0.423729,0.833333])     |0.0       |
|0.0  |(4,[0,1,2,3],[0.388889,-0.166667,0.525424,0.666667])     |0.0       |
|0.0  |(4,[0,1,2,3],[0.444444,-0.0833334,0.38983,0.833333])     |0.0       |
|0.0  |(4,[0,1,2,3],[0.555555,-0.166667,0.661017,0.666667])     |0.0       |
|0.0  |(4,[0,1,2,3],[0.722222,-0.333333,0.728813,0.5])          |0.0       |
|0.0  |(4,[0,1,2,3],[0.888889,-0.333333,0.932203,0.583333])     |0.0       |
|0.0  |(4,[0,1,2,3],[1.0,0.5,0.830508,0.583333])                |0.0       |
|0.0  |(4,[0,2,3],[0.166667,0.457627,0.833333])                 |0.0       |
|0.0  |(4,[0,2,3],[0.388889,0.661017,0.833333])                 |0.0       |
|1.0  |(4,[0,1,2,3],[-0.944444,-0.166667,-0.898305,-0.916667])  |1.0       |
|1.0  |(4,[0,1,2,3],[-0.722222,-0.166667,-0.864407,-0.833333])  |1.0       |
|1.0  |(4,[0,1,2,3],[-0.666667,-0.166667,-0.864407,-0.916667])  |1.0       |
|1.0  |(4,[0,1,2,3],[-0.666667,-0.0833334,-0.830508,-1.0])      |1.0       |
|1.0  |(4,[0,1,2,3],[-0.611111,0.166667,-0.79661,-0.75])        |1.0       |
|1.0  |(4,[0,1,2,3],[-0.555556,0.166667,-0.830508,-0.916667])   |1.0       |
|1.0  |(4,[0,1,2,3],[-0.555556,0.5,-0.830508,-0.833333])        |1.0       |
|1.0  |(4,[0,1,2,3],[-0.555556,0.5,-0.79661,-0.916667])         |1.0       |
|1.0  |(4,[0,1,2,3],[-0.5,0.166667,-0.864407,-0.916667])        |1.0       |
|1.0  |(4,[0,1,2,3],[-0.5,0.75,-0.830508,-1.0])                 |1.0       |
|1.0  |(4,[0,1,2,3],[-0.388889,0.166667,-0.830508,-0.75])       |1.0       |
|1.0  |(4,[0,1,2,3],[-0.388889,0.166667,-0.762712,-0.916667])   |1.0       |
|1.0  |(4,[0,1,2,3],[-0.388889,0.583333,-0.898305,-0.75])       |1.0       |
|1.0  |(4,[0,1,2,3],[-0.388889,0.583333,-0.762712,-0.75])       |1.0       |
|1.0  |(4,[0,1,2,3],[-0.333333,0.25,-0.898305,-0.916667])       |1.0       |
|1.0  |(4,[0,1,2,3],[-0.166667,0.666667,-0.932203,-0.916667])   |1.0       |
|1.0  |(4,[0,2,3],[-0.833333,-0.864407,-0.916667])              |1.0       |
|1.0  |(4,[0,2,3],[-0.777778,-0.898305,-0.916667])              |1.0       |
|2.0  |(4,[0,1,2,3],[-0.611111,-1.0,-0.152542,-0.25])           |2.0       |
|2.0  |(4,[0,1,2,3],[-0.555556,-0.583333,-0.322034,-0.166667])  |2.0       |
|2.0  |(4,[0,1,2,3],[-0.388889,-0.166667,0.186441,0.166667])    |2.0       |
|2.0  |(4,[0,1,2,3],[-0.333333,-0.666667,-0.0847458,-0.25])     |2.0       |
|2.0  |(4,[0,1,2,3],[-0.333333,-0.666667,-0.0508475,-0.166667]) |2.0       |
|2.0  |(4,[0,1,2,3],[-0.277778,-0.166667,0.186441,0.166667])    |2.0       |
|2.0  |(4,[0,1,2,3],[-0.222222,-0.5,-0.152542,-0.25])           |2.0       |
|2.0  |(4,[0,1,2,3],[-0.222222,-0.333333,0.0508474,-4.03573E-8])|2.0       |
|2.0  |(4,[0,1,2,3],[-0.111111,-0.166667,0.0847457,0.166667])   |2.0       |
|2.0  |(4,[0,1,2,3],[-0.0555556,-0.25,0.186441,0.166667])       |2.0       |
|2.0  |(4,[0,1,2,3],[-1.32455E-7,-0.25,0.254237,0.0833333])     |2.0       |
|2.0  |(4,[0,1,2,3],[0.0555554,-0.833333,0.186441,0.166667])    |2.0       |
|2.0  |(4,[0,1,2,3],[0.0555554,-0.25,0.118644,-4.03573E-8])     |2.0       |
|2.0  |(4,[0,1,2,3],[0.111111,0.0833333,0.254237,0.25])         |2.0       |
|2.0  |(4,[0,1,2,3],[0.333333,-0.166667,0.355932,0.333333])     |0.0       |
+-----+---------------------------------------------------------+----------+
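The accuracy figure can be checked by hand against the table: it lists 51 test rows, of which 5 are misclassified (four rows with label 0.0 predicted as 2.0, and one row with label 2.0 predicted as 0.0). The sketch below recomputes the fraction of matching (prediction, label) pairs from those counts. It is a plain-Scala illustration of what an accuracy metric measures, with hypothetical names, and is not the evaluator's actual code.

```scala
object AccuracySketch {
  // Accuracy = number of rows where prediction equals label / total rows
  def accuracy(pairs: Seq[(Double, Double)]): Double =
    pairs.count { case (prediction, label) => prediction == label }.toDouble / pairs.size

  // (prediction, label) pairs reconstructed from the counts in the table above
  val pairs: Seq[(Double, Double)] =
    Seq.fill(14)((0.0, 0.0)) ++ // label 0.0, predicted correctly
    Seq.fill(4)((2.0, 0.0)) ++  // label 0.0, predicted as 2.0
    Seq.fill(18)((1.0, 1.0)) ++ // label 1.0, predicted correctly
    Seq.fill(14)((2.0, 2.0)) ++ // label 2.0, predicted correctly
    Seq.fill(1)((0.0, 2.0))     // label 2.0, predicted as 0.0

  def main(args: Array[String]): Unit = {
    println(s"rows = ${pairs.size}, accuracy = ${accuracy(pairs)}")
  }
}
```

That is 46 correct predictions out of 51, which matches the reported test set accuracy of about 0.902.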
