// 根据每天的PM2.5数值分为优、良、轻度污染、中度污染等级别，并对这些级别进行预测
// (Bucket daily PM2.5 readings into pollution levels and predict the level.)
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.functions._
/**
* Created by Dank on 2017/8/17.
*/
object logPredict {
  Logger.getLogger("org").setLevel(Level.ERROR)

  /**
   * One hourly air-quality record parsed from pm.csv, plus the derived
   * pollution level (numeric code and Chinese label).
   */
  case class pm(No: Int, year: Int, month: Int, day: Int, hour: Int, pm: Double,
                DEWP: Int, TEMP: Double, PRES: Double, cbwd: String, Iws: Double,
                Is: Int, Ir: Int, levelNum: Double, levelStr: String)

  /**
   * Maps a PM2.5 reading to its (numeric level, Chinese label) pair using the
   * six standard AQI bands. Every band's upper edge is inclusive — the original
   * code used `< 50` for the first band, which wrongly pushed a reading of
   * exactly 50 into "良" instead of "优".
   */
  private def pollutionLevel(pm25: Double): (Double, String) =
    if (pm25 <= 50) (0.0, "优")
    else if (pm25 <= 100) (1.0, "良")
    else if (pm25 <= 150) (2.0, "轻度污染")
    else if (pm25 <= 200) (3.0, "中度污染")
    else if (pm25 <= 300) (4.0, "重度污染")
    else (5.0, "严重污染")

  /**
   * Loads pm.csv from the classpath root, derives a pollution-level label for
   * each clean record, trains a random forest on an 80/20 split, prints sample
   * predictions with the recovered Chinese level label, and reports test error.
   */
  def main(args: Array[String]) {
    val root = this.getClass.getResource("/")
    val conf = new SparkConf().setAppName("test").setMaster("local[*]")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Parse the CSV: drop the header row (first field "No") and any record
    // containing an empty or "NaN" field, then attach the pollution level
    // derived from column 5 (the PM2.5 reading).
    val parsedRDD = sc.textFile(root + "pm.csv")
      .map(_.split(","))
      .filter(!_(0).equals("No"))
      .filter(_.forall(field => !field.isEmpty && !field.equals("NaN")))
      .map(p => {
        val (levelNum, levelStr) = pollutionLevel(p(5).toDouble)
        pm(p(0).toInt, p(1).toInt, p(2).toInt, p(3).toInt, p(4).toInt, p(5).toDouble,
          p(6).toInt, p(7).toDouble, p(8).toDouble, p(9), p(10).toDouble,
          p(11).toInt, p(12).toInt, levelNum, levelStr)
      })

    import sqlContext.implicits._
    val pmDF = parsedRDD.toDF()
    pmDF.show(5)

    // Index the level label. NOTE: StringIndexer assigns indices by label
    // FREQUENCY, not by level order, so the model's "prediction" column must
    // be mapped back via IndexToString (below). The original hand-written UDF
    // assumed prediction == levelNum, which produced wrong labels whenever the
    // frequency ordering differed from the level ordering.
    val labelIndexer = new StringIndexer()
      .setInputCol("levelStr")
      .setOutputCol("label")
      .fit(pmDF)
    val indexer = new StringIndexer().setInputCol("cbwd").setOutputCol("cbwd_")
    val assembler = new VectorAssembler()
      .setInputCols(Array("month", "day", "hour", "DEWP", "TEMP", "PRES", "cbwd_", "Iws", "Is", "Ir"))
      .setOutputCol("features")

    val Array(trainingData, testData) = pmDF.randomSplit(Array(0.8, 0.2))
    testData.show(5)

    val rf = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")

    // Recover the original Chinese level label from the predicted index,
    // honoring whatever ordering the fitted StringIndexer actually produced.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictionStr")
      .setLabels(labelIndexer.labels)

    val pipeline = new Pipeline().setStages(Array(labelIndexer, indexer, assembler, rf, labelConverter))
    val model = pipeline.fit(trainingData)
    val predictions = model.transform(testData)
    predictions.show(10)

    // "precision" was removed as a MulticlassClassificationEvaluator metric
    // name in Spark 2.x; "accuracy" is the equivalent overall-correctness
    // metric. (NOTE(review): if this project pins Spark 1.x, "precision" was
    // still valid there — confirm the Spark version.)
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))
    sc.stop()
  }
}