This article was first published on my personal blog QIMING.INFO. If you repost it, please include a link and credit the author.
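MLPC (Multilayer Perceptron Classifier) is a classifier based on a feedforward artificial neural network (ANN). It is currently the only neural-network algorithm that Spark supports, and it lives in org.apache.spark.ml (not in mllib). This article walks through a small example of running MLPC on Spark.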
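A multilayer perceptron is a multilayer feedforward neural network model.
In a feedforward network, each layer, starting from the input layer, receives input only from the previous layer and passes its results only to the next layer. Nothing is fed back to earlier layers, so the whole computation can be represented as a directed acyclic graph. Such a network is made up of three kinds of layers: an input layer, one or more hidden layers, and an output layer, as shown in the figure:
[Figure: a multilayer perceptron with an input layer, hidden layer(s), and an output layer]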
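To make the layer structure concrete, here is a small stand-alone Scala sketch that counts the trainable parameters of a fully connected feedforward network from its layer sizes, such as the Array(4, 5, 4, 3) used in the example later on. It assumes each non-input layer has one weight per incoming connection plus one bias per neuron; the object and function names are hypothetical and not part of Spark.

```scala
// Hypothetical helper (not part of Spark): counts the trainable parameters
// of a fully connected feedforward network, assuming each non-input layer
// has one weight per incoming connection plus one bias per neuron.
object LayerParams {
  def weightCount(layers: Array[Int]): Int = {
    require(layers.length >= 2, "need at least an input and an output layer")
    // each consecutive pair (in, out) contributes in * out weights + out biases
    layers.sliding(2).map(pair => pair(0) * pair(1) + pair(1)).sum
  }

  def main(args: Array[String]): Unit = {
    val layers = Array(4, 5, 4, 3)
    println(s"${layers.mkString("-")} network: ${weightCount(layers)} parameters")
  }
}
```

Under that assumption the 4-5-4-3 network has (4*5+5) + (5*4+4) + (4*3+3) = 25 + 24 + 15 = 64 parameters.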
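MLPC is trained with the BP (Back Propagation) algorithm. The goal of BP learning is to adjust the network's connection weights so that the adjusted network produces outputs as close as possible to the desired outputs for its inputs. The "back propagation" in the name describes how the algorithm trains the network: it passes the error backward layer by layer and updates the connection weights between neurons one layer at a time, so that the error between the network's computed output and the expected output shrinks to an acceptable level.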
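The following is a minimal, self-contained sketch of backpropagation in plain Scala, not Spark's implementation: one sigmoid hidden layer, a softmax output layer, mean cross-entropy loss, and full-batch gradient descent on the XOR problem. The network size, initial weights, learning rate, and all names are illustrative assumptions. The output error p - onehot(y) is computed first, then pushed back through the output weights and the sigmoid derivative h(1 - h) to get the hidden-layer error, and finally every weight moves against its gradient.

```scala
import scala.util.Random

// Minimal backpropagation sketch (illustrative, not Spark's code):
// sigmoid hidden layer -> softmax output, mean cross-entropy loss,
// full-batch gradient descent. Sizes, data and learning rate are assumptions.
object BackpropSketch {
  type Mat = Array[Array[Double]]

  // w1: hidden x input, b1: hidden, w2: output x hidden, b2: output
  final case class Net(w1: Mat, b1: Array[Double], w2: Mat, b2: Array[Double])

  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  def softmax(z: Array[Double]): Array[Double] = {
    val m = z.max
    val exps = z.map(v => math.exp(v - m))
    val total = exps.sum
    exps.map(_ / total)
  }

  // Affine transform w * x + b
  def affine(w: Mat, b: Array[Double], x: Array[Double]): Array[Double] =
    w.indices.map(j => w(j).indices.map(i => w(j)(i) * x(i)).sum + b(j)).toArray

  // Forward pass: hidden activations h and output probabilities p
  def forward(net: Net, x: Array[Double]): (Array[Double], Array[Double]) = {
    val h = affine(net.w1, net.b1, x).map(sigmoid)
    (h, softmax(affine(net.w2, net.b2, h)))
  }

  // Mean cross-entropy loss over the data set
  def loss(net: Net, data: Seq[(Array[Double], Int)]): Double =
    data.map { case (x, y) => -math.log(forward(net, x)._2(y)) }.sum / data.size

  // One gradient-descent step; the error flows from the output layer backward
  def step(net: Net, data: Seq[(Array[Double], Int)], lr: Double): Net = {
    val gW1 = net.w1.map(row => Array.fill(row.length)(0.0))
    val gB1 = Array.fill(net.b1.length)(0.0)
    val gW2 = net.w2.map(row => Array.fill(row.length)(0.0))
    val gB2 = Array.fill(net.b2.length)(0.0)
    for ((x, y) <- data) {
      val (h, p) = forward(net, x)
      // output-layer error for softmax + cross-entropy: p - onehot(y)
      val dOut = p.indices.map(k => p(k) - (if (k == y) 1.0 else 0.0)).toArray
      // hidden-layer error: back through w2, times sigmoid' = h * (1 - h)
      val dHid = h.indices.map { j =>
        dOut.indices.map(k => net.w2(k)(j) * dOut(k)).sum * h(j) * (1.0 - h(j))
      }.toArray
      for (k <- dOut.indices) {
        gB2(k) += dOut(k)
        for (j <- h.indices) gW2(k)(j) += dOut(k) * h(j)
      }
      for (j <- dHid.indices) {
        gB1(j) += dHid(j)
        for (i <- x.indices) gW1(j)(i) += dHid(j) * x(i)
      }
    }
    val scale = lr / data.size // average the summed gradients
    def updV(w: Array[Double], g: Array[Double]): Array[Double] =
      w.zip(g).map { case (a, b) => a - scale * b }
    def updM(w: Mat, g: Mat): Mat = w.zip(g).map { case (wr, gr) => updV(wr, gr) }
    Net(updM(net.w1, gW1), updV(net.b1, gB1), updM(net.w2, gW2), updV(net.b2, gB2))
  }

  def initNet(nIn: Int, nHid: Int, nOut: Int, seed: Long): Net = {
    val rnd = new Random(seed)
    def mat(rows: Int, cols: Int): Mat = Array.fill(rows, cols)(rnd.nextGaussian() * 0.5)
    Net(mat(nHid, nIn), Array.fill(nHid)(0.0), mat(nOut, nHid), Array.fill(nOut)(0.0))
  }

  // XOR is not linearly separable, so the hidden layer is actually needed
  val xor: Seq[(Array[Double], Int)] = Seq(
    (Array(0.0, 0.0), 0), (Array(0.0, 1.0), 1),
    (Array(1.0, 0.0), 1), (Array(1.0, 1.0), 0))

  def main(args: Array[String]): Unit = {
    var net = initNet(nIn = 2, nHid = 3, nOut = 2, seed = 42L)
    println(f"loss before training: ${loss(net, xor)}%.4f")
    for (_ <- 1 to 2000) net = step(net, xor, lr = 0.5)
    println(f"loss after training:  ${loss(net, xor)}%.4f")
  }
}
```

Spark's own trainer does not run a loop like this; as far as we know it uses an optimizer such as L-BFGS (its default solver) over all weights at once, but the gradients it needs are the same kind of backpropagated quantities.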
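In Spark's multilayer perceptron, the hidden-layer neurons use the sigmoid function as their activation function, and the output layer uses the softmax function.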
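For reference, both activation functions fit in a few lines of plain Scala. This is a sketch for understanding, not Spark's code, and the object name is hypothetical. The softmax subtracts the largest input before exponentiating, a common trick that prevents overflow without changing the result; the sigmoid derivative s(1 - s) is the factor backpropagation multiplies by at each hidden layer.

```scala
// Stand-alone sketch of the two activation functions (not Spark's implementation).
object Activations {
  // sigmoid(z) = 1 / (1 + e^(-z)): squashes any real number into (0, 1)
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // derivative used during backpropagation: sigmoid'(z) = s * (1 - s)
  def sigmoidDerivative(z: Double): Double = {
    val s = sigmoid(z)
    s * (1.0 - s)
  }

  // softmax(z)_k = e^(z_k) / sum_j e^(z_j); subtracting max(z) first
  // avoids overflow and leaves the result unchanged
  def softmax(z: Array[Double]): Array[Double] = {
    val m = z.max
    val exps = z.map(v => math.exp(v - m))
    val total = exps.sum
    exps.map(_ / total)
  }

  def main(args: Array[String]): Unit = {
    println(sigmoid(0.0)) // 0.5
    println(softmax(Array(1.0, 2.0, 3.0)).mkString(", ")) // three probabilities summing to 1
    println(softmax(Array(1000.0, 1000.0)).mkString(", ")) // 0.5, 0.5 (no overflow)
  }
}
```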
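The important tunable parameters of MLPC are:
MLPC has strict requirements on its data source, which must be one of the following two types: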
This example uses a text file in LIBSVM format. The data looks like this:
[xuqm@cu01 ML_Data]$ cat input/sample_multiclass_classification_data.txt
1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333
1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667
1 1:-0.722222 2:-0.166667 3:-0.864407 4:-0.833333
1 1:-0.722222 2:0.166667 3:-0.694915 4:-0.916667
0 1:0.166667 2:-0.416667 3:0.457627 4:0.5
……
……
……
2 1:-0.388889 2:-0.166667 3:0.186441 4:0.166667
0 1:-0.222222 2:-0.583333 3:0.355932 4:0.583333
1 1:-0.611111 2:-0.166667 3:-0.79661 4:-0.916667
1 1:-0.944444 2:-0.25 3:-0.864407 4:-0.916667
1 1:-0.388889 2:0.166667 3:-0.830508 4:-0.75
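Each line holds a label followed by index:value pairs whose indices start at 1. The hypothetical parser below sketches how such a line maps to the sparse vectors shown in the prediction output later: indices are shifted to start at 0, and any feature missing from the line is implicitly 0, which is why some rows appear as (4,[0,2,3],[...]). It assumes the number of features (4 here) is passed in explicitly and is only an illustration, not Spark's libsvm reader; the second input line is made up to show an omitted feature.

```scala
// Hypothetical LIBSVM line parser (illustration only, not Spark's reader).
// Format: "label index1:value1 index2:value2 ..." with 1-based indices.
object LibsvmLine {
  final case class Sample(label: Double, size: Int, indices: Array[Int], values: Array[Double])

  def parse(line: String, numFeatures: Int): Sample = {
    val tokens = line.trim.split("\\s+")
    val pairs = tokens.tail.map { token =>
      val parts = token.split(":")
      (parts(0).toInt - 1, parts(1).toDouble) // shift the index to 0-based
    }
    Sample(tokens.head.toDouble, numFeatures, pairs.map(_._1), pairs.map(_._2))
  }

  // Render in the "(size,[indices],[values])" style of the sparse vectors in the output
  def render(s: Sample): String =
    s"(${s.size},[${s.indices.mkString(",")}],[${s.values.mkString(",")}])"

  def main(args: Array[String]): Unit = {
    val s = parse("1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333", numFeatures = 4)
    println(s"label = ${s.label}, features = ${render(s)}")
    // hypothetical line with feature 2 omitted: that feature is implicitly 0
    println(render(parse("0 1:0.166667 3:0.457627 4:0.833333", numFeatures = 4)))
  }
}
```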
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.SparkSession

object MLPCTest {
  def main(args: Array[String]): Unit = {
    // Create the SparkSession
    val spark = SparkSession.builder.appName("MLPCTest").getOrCreate()
    // Load the data stored in LIBSVM format
    val data = spark.read.format("libsvm").load("file:///home/xuqm/ML_Data/input/sample_multiclass_classification_data.txt")
    // Split randomly into a training set (~60%) and a test set (~40%)
    val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L)
    val train = splits(0)
    val test = splits(1)
    // Specify the layers of the neural network:
    // input layer with 4 nodes (4 features); two hidden layers with 5 and 4 nodes;
    // output layer with 3 nodes (3 classes)
    val layers = Array[Int](4, 5, 4, 3)
    // Create the MLPC trainer and set its parameters
    val trainer = new MultilayerPerceptronClassifier().
      setLayers(layers).
      setBlockSize(128).
      setSeed(1234L).
      setMaxIter(100)
    // Train the model
    val model = trainer.fit(train)
    // Use the trained model to predict the test set
    val result = model.transform(test)
    val predictionAndLabels = result.select("prediction", "label")
    // Compute the accuracy on the test set and print it
    val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
    println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels))
    // Print the predictions
    result.show(60,false)
    // Release the Spark resources
    spark.stop()
  }
}
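The call data.randomSplit(Array(0.6, 0.4), seed = 1234L) does not cut the data at exactly 60%/40%. The simplified sketch below, in plain Scala and not Spark's implementation (which samples each partition and differs in detail), shows the idea behind a weighted random split: normalize the weights into cumulative ranges, draw one uniform random number per row, and put the row into the split whose range contains it. Split sizes are therefore only approximately proportional to the weights. The function name and the row ids are hypothetical.

```scala
import scala.util.Random

// Simplified weighted random split (illustration only, not Spark's code):
// every row draws one uniform number u in [0, 1) and goes to the split
// whose normalized cumulative range [lower, upper) contains u.
object RandomSplitSketch {
  def randomSplit[T](rows: Seq[T], weights: Array[Double], seed: Long): Seq[Seq[T]] = {
    require(weights.nonEmpty && weights.forall(_ >= 0.0) && weights.sum > 0.0)
    val total = weights.sum
    // cumulative boundaries, e.g. weights (0.6, 0.4) -> 0.0, 0.6, 1.0
    val bounds = weights.scanLeft(0.0)(_ + _).map(_ / total)
    val rnd = new Random(seed)
    // draw exactly one random number per row (iterator + toVector keeps it strict)
    val tagged = rows.iterator.map { row =>
      val u = rnd.nextDouble()
      (bounds.indexWhere(_ > u) - 1, row) // index of the range containing u
    }.toVector
    weights.indices.map(i => tagged.collect { case (split, row) if split == i => row })
  }

  def main(args: Array[String]): Unit = {
    val rows = 1 to 150 // hypothetical row ids
    val splits = randomSplit(rows, Array(0.6, 0.4), seed = 1234L)
    // sizes are close to, but usually not exactly, 90 and 60
    println(s"train = ${splits(0).size}, test = ${splits(1).size}")
  }
}
```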
// Compute the accuracy on the test set and print it
Test set accuracy = 0.9019607843137255
// Print the predictions
result.show(60,false)
+-----+---------------------------------------------------------+----------+
|label|features |prediction|
+-----+---------------------------------------------------------+----------+
|0.0 |(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]) |2.0 |
|0.0 |(4,[0,1,2,3],[-0.277778,-0.333333,0.322034,0.583333]) |0.0 |
|0.0 |(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]) |0.0 |
|0.0 |(4,[0,1,2,3],[-0.0555556,-0.833333,0.355932,0.166667]) |2.0 |
|0.0 |(4,[0,1,2,3],[-0.0555556,-0.166667,0.288136,0.416667]) |2.0 |
|0.0 |(4,[0,1,2,3],[-1.32455E-7,-0.166667,0.322034,0.416667]) |2.0 |
|0.0 |(4,[0,1,2,3],[0.111111,-0.583333,0.355932,0.5]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.222222,-0.166667,0.627119,0.75]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.333333,-0.583333,0.627119,0.416667]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.333333,-0.166667,0.423729,0.833333]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.388889,-0.166667,0.525424,0.666667]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.444444,-0.0833334,0.38983,0.833333]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.555555,-0.166667,0.661017,0.666667]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.722222,-0.333333,0.728813,0.5]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.888889,-0.333333,0.932203,0.583333]) |0.0 |
|0.0 |(4,[0,1,2,3],[1.0,0.5,0.830508,0.583333]) |0.0 |
|0.0 |(4,[0,2,3],[0.166667,0.457627,0.833333]) |0.0 |
|0.0 |(4,[0,2,3],[0.388889,0.661017,0.833333]) |0.0 |
|1.0 |(4,[0,1,2,3],[-0.944444,-0.166667,-0.898305,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.722222,-0.166667,-0.864407,-0.833333]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.666667,-0.166667,-0.864407,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.666667,-0.0833334,-0.830508,-1.0]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.611111,0.166667,-0.79661,-0.75]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.555556,0.166667,-0.830508,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.555556,0.5,-0.830508,-0.833333]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.555556,0.5,-0.79661,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.5,0.166667,-0.864407,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.5,0.75,-0.830508,-1.0]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.388889,0.166667,-0.830508,-0.75]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.388889,0.166667,-0.762712,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.388889,0.583333,-0.898305,-0.75]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.388889,0.583333,-0.762712,-0.75]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.333333,0.25,-0.898305,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.166667,0.666667,-0.932203,-0.916667]) |1.0 |
|1.0 |(4,[0,2,3],[-0.833333,-0.864407,-0.916667]) |1.0 |
|1.0 |(4,[0,2,3],[-0.777778,-0.898305,-0.916667]) |1.0 |
|2.0 |(4,[0,1,2,3],[-0.611111,-1.0,-0.152542,-0.25]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.555556,-0.583333,-0.322034,-0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.388889,-0.166667,0.186441,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.333333,-0.666667,-0.0847458,-0.25]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.333333,-0.666667,-0.0508475,-0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.277778,-0.166667,0.186441,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.222222,-0.5,-0.152542,-0.25]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.222222,-0.333333,0.0508474,-4.03573E-8])|2.0 |
|2.0 |(4,[0,1,2,3],[-0.111111,-0.166667,0.0847457,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.0555556,-0.25,0.186441,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-1.32455E-7,-0.25,0.254237,0.0833333]) |2.0 |
|2.0 |(4,[0,1,2,3],[0.0555554,-0.833333,0.186441,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[0.0555554,-0.25,0.118644,-4.03573E-8]) |2.0 |
|2.0 |(4,[0,1,2,3],[0.111111,0.0833333,0.254237,0.25]) |2.0 |
|2.0 |(4,[0,1,2,3],[0.333333,-0.166667,0.355932,0.333333]) |0.0 |
+-----+---------------------------------------------------------+----------+
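To check the printed accuracy against the table, the sketch below encodes the 51 rows shown above as (label, prediction) pairs and recomputes the accuracy and a confusion matrix in plain Scala. The counts were read off the table by hand, and the helper names are hypothetical. 46 of the 51 predictions are correct, and 46/51 ≈ 0.90196, which agrees with the Test set accuracy printed above.

```scala
// Hypothetical helpers: accuracy and confusion matrix from (label, prediction) pairs.
object ConfusionSketch {
  def confusion(pairs: Seq[(Int, Int)], numClasses: Int): Array[Array[Int]] = {
    val m = Array.fill(numClasses, numClasses)(0) // rows = true label, columns = prediction
    for ((label, prediction) <- pairs) m(label)(prediction) += 1
    m
  }

  // accuracy = correctly classified rows / all rows
  def accuracy(pairs: Seq[(Int, Int)]): Double =
    pairs.count { case (label, prediction) => label == prediction }.toDouble / pairs.size

  // The 51 (label, prediction) pairs from the result.show output, tallied by hand
  val shown: Seq[(Int, Int)] =
    Seq.fill(14)((0, 0)) ++ Seq.fill(4)((0, 2)) ++
      Seq.fill(18)((1, 1)) ++
      Seq.fill(14)((2, 2)) ++ Seq.fill(1)((2, 0))

  def main(args: Array[String]): Unit = {
    val m = confusion(shown, numClasses = 3)
    m.zipWithIndex.foreach { case (row, label) => println(s"label $label: ${row.mkString(" ")}") }
    println(s"accuracy = ${accuracy(shown)}") // 46 / 51
    for (c <- 0 until 3) println(f"recall of class $c = ${m(c)(c).toDouble / m(c).sum}%.4f")
  }
}
```

The confusion matrix makes the errors visible: four class-0 samples were predicted as class 2 and one class-2 sample as class 0, while every class-1 sample was classified correctly.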