Environment: Spark 2.0 · Scala 2.11.8
Searching online for examples that combine Spark MLlib with Spark Streaming turns up almost nothing, which puzzles me. Is there a more sensible way to do near-real-time prediction? If you know one, please point it out in the comments. The idea in this post is simple: train and save a model with Spark MLlib, then write a Spark Streaming program that loads the model and uses it for prediction. Note that before using Spark MLlib I used Python to explore and analyze the data, clean it, do feature engineering, build the dataset, prototype models, and so on, and the dataset built in Python is used directly in this post.
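The training code below assumes each row of the CSV built in Python consists of the numeric feature columns followed by the binary label (0 or 1) as the last column. A made-up illustration of that layout (the real feature columns are not shown in this post):

0.37,1.0,12.5,0.0,1
0.92,0.0,3.4,1.0,0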
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD
/**
* Train and save the model
* Created by drguo on 2020/5/20 11:34.
*/
object RandomForestM {
def main(args: Array[String]) {
val sparkConf = new SparkConf()
// Local mode; * uses all available CPU cores
.setMaster("local[*]")
.setAppName("rf")
val sc = new SparkContext(sparkConf)
// Load the data
val rawData = sc.textFile("hdfs://xx:8020/model/data/xx.csv")
val data = rawData.map { line =>
val values = line.split(",").map(_.toDouble)
// init returns every element except the last; these become the feature vector
// Vectors.dense builds a dense vector from them
val feature = Vectors.dense(values.init)
val label = values.last
LabeledPoint(label, feature)
}
// Split into training, cross-validation and test sets: 80%, 10% and 10%
// The 10% CV set is used to pick the best hyperparameters for the model trained on the training set
// The test set then gives an unbiased evaluation of the model trained with those best parameters
val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
trainData.cache()
cvData.cache()
testData.cache()
// Train the random forest
// (numClasses = 2, no categorical features, numTrees = 20, featureSubsetStrategy = "auto",
// impurity = "gini", maxDepth = 4, maxBins = 32)
val model = RandomForest.trainClassifier(trainData, 2, Map[Int, Int](), 20, "auto", "gini", 4, 32)
val metrics = getMetrics(model, cvData)
// Confusion matrix and overall accuracy
println(metrics.confusionMatrix)
println(metrics.accuracy)
// Precision and recall for each class
(0 until 2).map(target => (metrics.precision(target), metrics.recall(target))).foreach(println)
// Save the model to HDFS
model.save(sc, "hdfs://xx:8020/model/xxModel")
}
/**
* @param model the trained random forest model
* @param data  the cross-validation dataset
**/
def getMetrics(model: RandomForestModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
// Have the model predict a label for each CV sample's feature vector and pair it with the true label
val predictionsAndLabels = data.map { d =>
(model.predict(d.features), d.label)
}
// MulticlassMetrics computes various quality measures of the classifier's predictions
new MulticlassMetrics(predictionsAndLabels)
}
/**
* Find the best parameter combination on the training set
*
* @param trainData training set
* @param cvData    cross-validation set
**/
def getBestParam(trainData: RDD[LabeledPoint], cvData: RDD[LabeledPoint]): Unit = {
val evaluations = for (impurity <- Array("gini", "entropy");
depth <- Array(1, 20);
bins <- Array(10, 300)) yield {
val model = RandomForest.trainClassifier(trainData, 2, Map[Int, Int](), 20, "auto", impurity, depth, bins)
val metrics = getMetrics(model, cvData)
((impurity, depth, bins), metrics.accuracy)
}
evaluations.sortBy(_._2).reverse.foreach(println)
}
}
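getBestParam above prints the candidate results but is never called from main, and it returns Unit rather than the winning combination. A minimal sketch, assuming you want the best tuple back programmatically (the name getBestParamReturned and its use in main are my own, not from the original code):

def getBestParamReturned(trainData: RDD[LabeledPoint], cvData: RDD[LabeledPoint]): ((String, Int, Int), Double) = {
val evaluations = for (impurity <- Array("gini", "entropy");
depth <- Array(1, 20);
bins <- Array(10, 300)) yield {
val model = RandomForest.trainClassifier(trainData, 2, Map[Int, Int](), 20, "auto", impurity, depth, bins)
((impurity, depth, bins), getMetrics(model, cvData).accuracy)
}
// The combination with the highest CV accuracy wins
evaluations.maxBy(_._2)
}

// Possible use in main, before the final trainClassifier call:
// val ((impurity, depth, bins), _) = getBestParamReturned(trainData, cvData)
// val model = RandomForest.trainClassifier(trainData, 2, Map[Int, Int](), 20, "auto", impurity, depth, bins)

The Spark Streaming program below loads the saved model from HDFS and applies it to messages read from Kafka.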
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
import org.apache.spark.streaming.kafka010.{HasOffsetRanges, KafkaUtils, OffsetRange}
/**
* Created by drguo on 2020/5/20 11:34.
*/
object ModelTest {
private val brokers = "xx1:6667,xx2:6667,xx3:6667"
def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf()
// Local mode; * uses all available CPU cores
.setMaster("local[*]")
.setAppName("ModelTest")
sparkConf.set("spark.sql.warehouse.dir", "file:///") // local warehouse directory
val sc = new SparkContext(sparkConf)
// Load the saved model from HDFS
val rfModel = RandomForestModel.load(sc, "hdfs://xx:8020/model/xxModel")
val ssc = new StreamingContext(sc, Seconds(6))
val topics = Array("xx1", "xx2")
val kafkaParams = Map[String, Object](
"bootstrap.servers" -> brokers,
"key.deserializer" -> classOf[StringDeserializer],
"value.deserializer" -> classOf[StringDeserializer],
"group.id" -> "hqc",
"auto.offset.reset" -> "latest",
"enable.auto.commit" -> (false: java.lang.Boolean)
)
val messages: InputDStream[ConsumerRecord[String, String]] = KafkaUtils.createDirectStream[String, String](
ssc,
PreferConsistent,
Subscribe[String, String](topics, kafkaParams)
)
// offsetRanges holds one OffsetRange per Kafka partition in this batch
messages.foreachRDD(rdd => {
val offsetRanges: Array[OffsetRange] = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
rdd.foreach((msg: ConsumerRecord[String, String]) => {
val o = offsetRanges(TaskContext.get.partitionId)
// println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}")
val topic: String = o.topic
topic match {
case "xx1" =>
val line = KxxDataClean(msg.value)
if (line != "") {
val values = line.split(",").map(_.toDouble)
val feature = Vectors.dense(values)
// Predict with the loaded model
val preLabel = rfModel.predict(feature)
println(preLabel)
}
case "xx2" =>
}
})
})
ssc.start()
ssc.awaitTermination()
}
}
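KxxDataClean is called in the match above but its implementation is not included in this post. A hypothetical placeholder sketch, assuming its job is to turn a raw Kafka message into the comma-separated numeric feature string the model expects and to return an empty string for records that should be skipped (the validation logic here is invented for illustration and would live inside ModelTest):

// Hypothetical stand-in for the real cleaning logic, not the original implementation
def KxxDataClean(raw: String): String = {
val fields = raw.trim.split(",")
// Skip empty records and records containing non-numeric fields
if (fields.isEmpty || fields.exists(f => scala.util.Try(f.toDouble).isFailure)) ""
else fields.mkString(",")
}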
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.xx</groupId>
<artifactId>spark-xx-model</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<spark.version>2.0.0</spark.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>${spark.version}</version>
<scope>compile</scope>
</dependency>
<!--spark streaming + kafka-->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_2.11</artifactId>
<version>${spark.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka_2.11</artifactId>
<version>0.10.0.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming-kafka-0-10_2.11</artifactId>
<version>${spark.version}</version>
</dependency>
<!--mysql-->
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.1.39</version>
</dependency>
<!-- logging -->
<dependency>
<groupId>com.typesafe.scala-logging</groupId>
<artifactId>scala-logging_2.11</artifactId>
<version>3.7.2</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>1.2.3</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<version>2.3</version>
<configuration>
<classifier>dist</classifier>
<appendAssemblyId>true</appendAssemblyId>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assembly</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
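<!-- Not part of the original pom: to compile the Scala sources with Maven you also need a
Scala compiler plugin; scala-maven-plugin is a common choice (the version here is an assumption). -->
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.2.2</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
</plugin>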
</plugins>
</build>
</project>