The implementation is split into three parts: reading the data from Hive, training a random forest model, and running prediction on new data.
1. Random forest training code
package com.inspur.mr.InspurMr.Classification
import java.io.File
import java.io.PrintWriter
import java.util.ArrayList
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.linalg.{ Vector, Vectors }
import com.inspur.mr.InspurMr.conf.RandomForestConf
import com.inspur.mr.InspurMr.Util.Quota
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.hadoop.fs.{FileStatus, FileSystem, FileUtil, Path}
import java.text.SimpleDateFormat
import java.util.Date
object RandomWithHive extends RandomForestConf {
def main(args: Array[String]): Unit = {
import hc.implicits._
// Read the grid/area parameters from configuration, then fetch the training data from Hive
val database = paraproperties.getProperty("database")
val null_fill = paraproperties.getProperty("null_fill")
val eare_lon_left = paraproperties.getProperty("eare_lon_left")
val eare_lat_left = paraproperties.getProperty("eare_lat_left")
val eare_lon_right = paraproperties.getProperty("eare_lon_right")
val eare_lat_right = paraproperties.getProperty("eare_lat_right")
val grid_length = paraproperties.getProperty("grid_length")
val grid_num = paraproperties.getProperty("grid_num").toInt
val disgrid = grid_length.toDouble*0.000009
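// Assumption: grid_length is given in metres; one metre is roughly 0.000009 degrees of latitude (1 degree ≈ 111 km), so disgrid is the grid cell size expressed in degrees.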
hc.sql(s"use $database")
val data1 = hc.sql(s"""select floor(($eare_lat_left-lat_uri)/$disgrid)*$grid_num+floor((long_uri-$eare_lon_left)/$disgrid) as llgridid,cellid*1.0 as cellid,ltesctadv*1.0 as ltesctadv,ltescaoa*1.0 as ltescaoa,ltescphr*1.0 as ltescphr,ltescrip*1.0 as ltescrip,ltescsinrul*1.0 as ltescsinrul,ltescearfcn*1.0 as ltescearfcn,ltescpci*1.0 as ltescpci,LON0*1.0 as LON0,LAT0*1.0 as LAT0,azimuth0*1.0 as azimuth0,coverflag0*1.0 as coverflag0,nettype0*1.0 as nettype0,ltescrsrp*1.0 as ltescrsrp,ltescrsrq*1.0 as ltescrsrq,ltencrsrp1*1.0 as ltencrsrp1,ltencrsrq1*1.0 as ltencrsrq1,ltencearfcn1*1.0 as ltencearfcn1,ltencpci1*1.0 as ltencpci1,ltencrsrp2*1.0 as ltencrsrp2,ltencrsrq2*1.0 as ltencrsrq2,ltencearfcn2*1.0 as ltencearfcn2,ltencpci2*1.0 as ltencpci2,ltencrsrp3*1.0 as ltencrsrp3,ltencrsrq3*1.0 as ltencrsrq3,ltencearfcn3*1.0 as ltencearfcn3,ltencpci3*1.0 as ltencpci3 from dw_pods_mro_eutrancell_yyyymmdd where lat_uri<$eare_lat_left and lat_uri>$eare_lat_right and long_uri>$eare_lon_left and long_uri<$eare_lon_right and pow(long_uri-LON0,2)+pow(lat_uri-LAT0,2)<0.00002025 order by hour_id desc limit 30000000""".stripMargin)
// val pathpath = "file:///C:\\Users\\wangkai01\\Desktop\\data\\csvtest.csv"
val data = data1.na.fill(null_fill.toDouble).cache()
println(s"""select floor(($eare_lat_left-lat_uri)/$disgrid)*$grid_num+floor((long_uri-$eare_lon_left)/$disgrid) as llgridid,cellid*1.0 as cellid,ltesctadv*1.0 as ltesctadv,ltescaoa*1.0 as ltescaoa,ltescphr*1.0 as ltescphr,ltescrip*1.0 as ltescrip,ltescsinrul*1.0 as ltescsinrul,ltescearfcn*1.0 as ltescearfcn,ltescpci*1.0 as ltescpci,LON0*1.0 as LON0,LAT0*1.0 as LAT0,azimuth0*1.0 as azimuth0,coverflag0*1.0 as coverflag0,nettype0*1.0 as nettype0,ltescrsrp*1.0 as ltescrsrp,ltescrsrq*1.0 as ltescrsrq,ltencrsrp1*1.0 as ltencrsrp1,ltencrsrq1*1.0 as ltencrsrq1,ltencearfcn1*1.0 as ltencearfcn1,ltencpci1*1.0 as ltencpci1,ltencrsrp2*1.0 as ltencrsrp2,ltencrsrq2*1.0 as ltencrsrq2,ltencearfcn2*1.0 as ltencearfcn2,ltencpci2*1.0 as ltencpci2,ltencrsrp3*1.0 as ltencrsrp3,ltencrsrq3*1.0 as ltencrsrq3,ltencearfcn3*1.0 as ltencearfcn3,ltencpci3*1.0 as ltencpci3 from dw_pods_mro_eutrancell_yyyymmdd where lat_uri<$eare_lat_left and lat_uri>$eare_lat_right and long_uri>$eare_lon_left and long_uri<$eare_lon_right and pow(long_uri-LON0,2)+pow(lat_uri-LAT0,2)<0.00002025 order by hour_id desc limit 30000000""".stripMargin)
println("run here1 !!!!!!!!")
// data.show()
// Feature columns, referenced by index
val featInd = List("cellid", "ltesctadv", "ltescaoa", "ltescphr", "ltescrip", "ltescsinrul", "ltescearfcn", "ltescpci", "LON0", "LAT0", "azimuth0", "coverflag0", "nettype0", "ltescrsrp", "ltescrsrq", "ltencrsrp1", "ltencrsrq1", "ltencearfcn1", "ltencpci1", "ltencrsrp2", "ltencrsrq2", "ltencearfcn2", "ltencpci2", "ltencrsrp3", "ltencrsrq3", "ltencearfcn3", "ltencpci3").map(data.columns.indexOf(_))
println(featInd)
// Label column: the grid id
val Label = data.columns.indexOf("llgridid")
val datause = data.map { x =>
val label = x(0).toString().toInt
val feature = Vectors.dense(featInd.map(x.getDouble(_)).toArray)
// println(feature)
LabeledPoint(label, feature)
}
println("run here2 !!!!!!!!")
//Split the data into training and test sets
val splits = datause.randomSplit(Array(tarining_rate, test_rate))
val (trainingData, testData) = (splits(0), splits(1))
//Number of distinct labels in the training samples, used as the number of classes
// val numClasses = (datause.map { x => x.label }.max() + 1).toInt
val numClasses = class_num
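// class_num must be at least max(llgridid) + 1, since MLlib expects labels in [0, numClasses); the commented-out line above derives the same value from the data.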
//Count the training and test samples
val train_sample = trainingData.count()
val test_sample = testData.count()
println("run here3 !!!!!!!!")
//If output from a previous training run exists, delete it, then (re)create the model output directory
val path = new Path(model_out_path);
val hdfs = org.apache.hadoop.fs.FileSystem.get(
new java.net.URI(model_out_path), new org.apache.hadoop.conf.Configuration())
if (!hdfs.exists(path)){
hdfs.mkdirs(path)
}else{
hdfs.delete(path, true)
hdfs.mkdirs(path)
}
var bestscore=0.0
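// Grid search over the number of trees and the maximum depth; the model with the best test accuracy (1 - testErr) is kept as model-RF-best.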
for (numTrees <- treeList; maxDepth <- depthList) {
val s = Strategy.defaultStrategy("Classification")
s.setMaxMemoryInMB(2048)
s.setNumClasses(numClasses)
s.setMaxDepth(maxDepth)
s.setMaxBins(maxBins)
val model = RandomForest.trainClassifier(trainingData, s, numTrees, featureSubsetStrategy, 10)
// Evaluate the trained classifier on the test data and compute the error rate
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val quota = Quota.calculate(labelAndPreds, testData)
val testErr = quota._1
// val testRecall = quota._3
// val f1_score = quota._4
println("Test Error = " + testErr)
// println("Learned classification forest model:\n" + model.toDebugString)
// hdfs.createNewFile(new Path(describe + s"result-$numTrees-$maxDepth-$testErr.txt"))
// val dirfile = new File(describe);
// if (!dirfile.isDirectory()) {
// dirfile.mkdirs()
// }
// val resultfile = new File(describe + s"result-$numTrees-$maxDepth-$testErr.txt")
// if(!resultfile.isFile()){
// val writer = new PrintWriter(resultfile)
// // writer.println("train pos count:" + pos_sample + "\n")
// // writer.println("train neg count:" + neg_sample + "\n")
// writer.println("train count:" + train_sample + "\n")
// writer.println("test count:" + test_sample + "\n")
// writer.println("Test Error = " + testErr + "\n")
// writer.println(model.toDebugString)
// writer.close()
// }
println(s"model-$numTrees-$maxDepth:"+(1-testErr))
println(model.toDebugString)
// Persist the trained random forest model
val now: Date = new Date()
val dateFormat: SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd-HH-mm-ss")
val date = dateFormat.format(now)
val path = new Path(model_out_path + s"model-$numTrees-$maxDepth-$testErr-$date");
// Save the model only if one for these parameters does not already exist
val hdfs = org.apache.hadoop.fs.FileSystem.get(
new java.net.URI(model_out_path + s"model-$numTrees-$maxDepth-$testErr-$date"), new org.apache.hadoop.conf.Configuration())
if (!hdfs.exists(path)){
model.save(sc, model_out_path + s"model-$numTrees-$maxDepth-$testErr-$date")
}
if(1-testErr>=bestscore){
//Replace the previously saved best model with the current one
val path = new Path(model_file);
val hdfs = org.apache.hadoop.fs.FileSystem.get(
new java.net.URI(model_file), new org.apache.hadoop.conf.Configuration())
if (hdfs.exists(path)) hdfs.delete(path, true)
model.save(sc, model_out_path + "model-RF-best")
bestscore = 1-testErr
}
}
sc.stop()
println("best score:"+bestscore)
println("run done !!!!!!!!")
}
}
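Quota.calculate above is a project utility whose source is not shown; only its first return value (the test error) is used. A minimal sketch, assuming that value is simply the misclassification rate on the test set:

// Hypothetical stand-in for Quota.calculate's first return value:
// the fraction of test points whose predicted grid id differs from the true label.
val testErr = labelAndPreds.filter { case (label, pred) => label != pred }.count.toDouble / testData.count()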
2. Random forest prediction code
package com.inspur.mr.InspurMr.Classification
import com.inspur.mr.InspurMr.conf.AppConf
import org.apache.spark.mllib.tree.model.RandomForestModel
import com.inspur.mr.InspurMr.Util.MLUtils
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}
object RandomPredict extends AppConf {
case class TableMrPre(msisdn:String,imsi:String,imei:String,begintime:String,tac:String,eci:String,nettype0:String,long_uri:Double,lat_uri:Double)
def main(args: Array[String]): Unit = {
val database = paraproperties.getProperty("database")
val null_fill = paraproperties.getProperty("null_fill")
val eare_lon_left = paraproperties.getProperty("eare_lon_left").toDouble
val eare_lat_left = paraproperties.getProperty("eare_lat_left").toDouble
val eare_lon_right = paraproperties.getProperty("eare_lon_right")
val eare_lat_right = paraproperties.getProperty("eare_lat_right")
val grid_length = paraproperties.getProperty("grid_length")
val grid_num = paraproperties.getProperty("grid_num").toDouble
val disgrid = grid_length.toDouble*0.000009
val disgridhalf = grid_length.toDouble*0.000009/2
var HOUR_ID = args(0)
var MONTH_ID = HOUR_ID.substring(0,6)
var DAY_ID = HOUR_ID.substring(0,8)
val write_partition = "month_id="+MONTH_ID+","+"day_id="+DAY_ID+","+"hour_id="+HOUR_ID
val read_partition = "month_id="+MONTH_ID+" and "+"day_id="+DAY_ID+" and "+"hour_id="+HOUR_ID
conf.setAppName("family_test")
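// model_file is assumed to point at the best model written by the training job (model-RF-best).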
val pModlePath = postgprop.getProperty("model_file")
hc.sql(s"use $database")
val data = hc.sql(s"""select cellid*1.0 as cellid,ltesctadv*1.0 as ltesctadv,ltescaoa*1.0 as ltescaoa,ltescphr*1.0 as ltescphr,ltescrip*1.0 as ltescrip,ltescsinrul*1.0 as ltescsinrul,ltescearfcn*1.0 as ltescearfcn,ltescpci*1.0 as ltescpci,LON0*1.0 as LON0,LAT0*1.0 as LAT0,azimuth0*1.0 as azimuth0,coverflag0*1.0 as coverflag0,nettype0*1.0 as nettype0,ltescrsrp*1.0 as ltescrsrp,ltescrsrq*1.0 as ltescrsrq,ltencrsrp1*1.0 as ltencrsrp1,ltencrsrq1*1.0 as ltencrsrq1,ltencearfcn1*1.0 as ltencearfcn1,ltencpci1*1.0 as ltencpci1,ltencrsrp2*1.0 as ltencrsrp2,ltencrsrq2*1.0 as ltencrsrq2,ltencearfcn2*1.0 as ltencearfcn2,ltencpci2*1.0 as ltencpci2,ltencrsrp3*1.0 as ltencrsrp3,ltencrsrq3*1.0 as ltencrsrq3,ltencearfcn3*1.0 as ltencearfcn3,ltencpci3*1.0 as ltencpci3,msisdn,imsi,imei,begintime,tac,eci,nettype0 from dw_pods_mro_eutrancell_pre_yyyymmdd where $read_partition""".stripMargin)
println(s"""select cellid*1.0 as cellid,ltesctadv*1.0 as ltesctadv,ltescaoa*1.0 as ltescaoa,ltescphr*1.0 as ltescphr,ltescrip*1.0 as ltescrip,ltescsinrul*1.0 as ltescsinrul,ltescearfcn*1.0 as ltescearfcn,ltescpci*1.0 as ltescpci,LON0*1.0 as LON0,LAT0*1.0 as LAT0,azimuth0*1.0 as azimuth0,coverflag0*1.0 as coverflag0,nettype0*1.0 as nettype0,ltescrsrp*1.0 as ltescrsrp,ltescrsrq*1.0 as ltescrsrq,ltencrsrp1*1.0 as ltencrsrp1,ltencrsrq1*1.0 as ltencrsrq1,ltencearfcn1*1.0 as ltencearfcn1,ltencpci1*1.0 as ltencpci1,ltencrsrp2*1.0 as ltencrsrp2,ltencrsrq2*1.0 as ltencrsrq2,ltencearfcn2*1.0 as ltencearfcn2,ltencpci2*1.0 as ltencpci2,ltencrsrp3*1.0 as ltencrsrp3,ltencrsrq3*1.0 as ltencrsrq3,ltencearfcn3*1.0 as ltencearfcn3,ltencpci3*1.0 as ltencpci3,msisdn,imsi,imei,begintime,tac,eci,nettype0 from dw_pods_mro_eutrancell_pre_yyyymmdd where $read_partition""")
println("run here1 !!!!!!!!")
//data.show() //
val sameModel = RandomForestModel.load(sc, pModlePath)
println("run here2!!!!!")
val labelAndPreds = data.map { row =>
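// Replace null feature values with -2 before assembling the feature vector.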
def isNull(xarr:Any):String = if (null==xarr) "-2" else xarr.toString()
val rowStr = isNull(row(0))+" "+isNull(row(1))+" "+isNull(row(2))+" "+isNull(row(3))+" "+isNull(row(4))+" "+isNull(row(5))+" "+isNull(row(6))+" "+isNull(row(7))+" "+isNull(row(8))+" "+isNull(row(9))+" "+isNull(row(10))+" "+isNull(row(11))+" "+isNull(row(12))+" "+isNull(row(13))+" "+isNull(row(14))+" "+isNull(row(15))+" "+isNull(row(16))+" "+isNull(row(17))+" "+isNull(row(18))+" "+isNull(row(19))+" "+isNull(row(20))+" "+isNull(row(21))+" "+isNull(row(22))+" "+isNull(row(23))+" "+isNull(row(24))+" "+isNull(row(25))+" "+isNull(row(26))
val prediction = sameModel.predict(Vectors.dense(rowStr.split(' ').map { _.toDouble }))
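// Decode the predicted grid id back into coordinates: column = id % grid_num, row = id / grid_num; the half-grid offset places the point at the centre of the predicted grid cell.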
val glong = prediction%grid_num
val glat = prediction/grid_num
val lonPre=glong*disgrid+eare_lon_left+disgridhalf
val latPre=eare_lat_left-glat*disgrid-disgridhalf
TableMrPre(isNull(row(27)),isNull(row(28)),isNull(row(29)),isNull(row(30)),isNull(row(31)),isNull(row(32)),isNull(row(33)),lonPre,latPre)
}.cache
println("run here4!!!!!")
import hc.implicits._
val tabledf = labelAndPreds.toDF()
// tabledf.show(100)
tabledf.registerTempTable("TempTableMrPre")
hc.sql("insert OVERWRITE table dw_mr_mme_position_pre partition("+write_partition+") select * from TempTableMrPre")
hc.dropTempTable("TempTableMrPre")
sc.stop()
println("run done!!!!!")
}
}
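The prediction job takes a single argument, the hour partition id, from which the month (first six characters) and day (first eight characters) partitions are derived, i.e. an id of the form yyyymmddhh.

For reference, a minimal self-contained sketch (not part of the original jobs) of the grid encode/decode convention shared by training and prediction, assuming a grid of gridNum columns anchored at the top-left corner (eare_lon_left, eare_lat_left) with a cell size of disgrid degrees:

object GridCodec {
  // Map a point to its grid id, mirroring the llgridid expression in the training SQL.
  def encode(lon: Double, lat: Double,
             lonLeft: Double, latLeft: Double,
             disgrid: Double, gridNum: Int): Int = {
    val row = math.floor((latLeft - lat) / disgrid).toInt // rows grow southwards
    val col = math.floor((lon - lonLeft) / disgrid).toInt // columns grow eastwards
    row * gridNum + col
  }

  // Map a grid id back to the centre of its cell, mirroring the decode step in RandomPredict.
  def decode(gridId: Int,
             lonLeft: Double, latLeft: Double,
             disgrid: Double, gridNum: Int): (Double, Double) = {
    val col = gridId % gridNum
    val row = gridId / gridNum
    (lonLeft + col * disgrid + disgrid / 2, latLeft - row * disgrid - disgrid / 2)
  }
}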