Random forest classification with Spark MLlib: code notes, to be completed later

scala+spark+randomForests

The implementation is split into three parts: reading the data from Hive, training a random forest model, and predicting on new data.

  1. Main class for random forest model training (implementation flow)
package com.inspur.mr.InspurMr.Classification
import java.io.File
import java.io.PrintWriter
import java.util.ArrayList
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.linalg.{ Vector, Vectors }
import com.inspur.mr.InspurMr.conf.RandomForestConf
import com.inspur.mr.InspurMr.Util.Quota
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.hadoop.fs.{FileStatus, FileSystem, FileUtil, Path}
import java.text.SimpleDateFormat
import java.util.Date

object RandomWithHive extends RandomForestConf {

  def main(args: Array[String]): Unit = {
    import hc.implicits._
    //    Read the training data from Hive
    val database = paraproperties.getProperty("database")
    val null_fill = paraproperties.getProperty("null_fill")
    val eare_lon_left = paraproperties.getProperty("eare_lon_left")
    val eare_lat_left = paraproperties.getProperty("eare_lat_left")
    val eare_lon_right = paraproperties.getProperty("eare_lon_right")
    val eare_lat_right = paraproperties.getProperty("eare_lat_right")
    val grid_length = paraproperties.getProperty("grid_length")
    val grid_num = paraproperties.getProperty("grid_num").toInt
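    // grid edge length converted from meters to degrees (1 m is roughly 0.000009 degrees of latitude)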
    val disgrid = grid_length.toDouble*0.000009

    hc.sql(s"use $database")
    // llgridid: grid index = row * grid_num + column, counted from the north-west corner of the area;
    // samples are limited to the target rectangle and to points within about 0.0045 degrees (~500 m) of the serving cell
    val trainSql = s"""select floor(($eare_lat_left-lat_uri)/$disgrid)*$grid_num+floor((long_uri-$eare_lon_left)/$disgrid) as llgridid,cellid*1.0 as cellid,ltesctadv*1.0 as ltesctadv,ltescaoa*1.0 as ltescaoa,ltescphr*1.0 as ltescphr,ltescrip*1.0 as ltescrip,ltescsinrul*1.0 as ltescsinrul,ltescearfcn*1.0 as ltescearfcn,ltescpci*1.0 as ltescpci,LON0*1.0 as LON0,LAT0*1.0 as LAT0,azimuth0*1.0 as azimuth0,coverflag0*1.0 as coverflag0,nettype0*1.0 as nettype0,ltescrsrp*1.0 as ltescrsrp,ltescrsrq*1.0 as ltescrsrq,ltencrsrp1*1.0 as ltencrsrp1,ltencrsrq1*1.0 as ltencrsrq1,ltencearfcn1*1.0 as ltencearfcn1,ltencpci1*1.0 as ltencpci1,ltencrsrp2*1.0 as ltencrsrp2,ltencrsrq2*1.0 as ltencrsrq2,ltencearfcn2*1.0 as ltencearfcn2,ltencpci2*1.0 as ltencpci2,ltencrsrp3*1.0 as ltencrsrp3,ltencrsrq3*1.0 as ltencrsrq3,ltencearfcn3*1.0 as ltencearfcn3,ltencpci3*1.0 as ltencpci3 from dw_pods_mro_eutrancell_yyyymmdd where lat_uri<$eare_lat_left and lat_uri>$eare_lat_right and long_uri>$eare_lon_left and long_uri<$eare_lon_right and pow(long_uri-LON0,2)+pow(lat_uri-LAT0,2)<0.00002025 order by hour_id desc limit 30000000""".stripMargin
    println(trainSql)
    val data1 = hc.sql(trainSql)

    //    val pathpath = "file:///C:\\Users\\wangkai01\\Desktop\\data\\csvtest.csv"
    val data = data1.na.fill(null_fill.toDouble).cache()
    println("run here1 !!!!!!!!")
//    data.show()
    //    Feature columns, mapped to their indexes in the DataFrame
    val featInd = List("cellid", "ltesctadv", "ltescaoa", "ltescphr", "ltescrip", "ltescsinrul", "ltescearfcn", "ltescpci", "LON0", "LAT0", "azimuth0", "coverflag0", "nettype0", "ltescrsrp", "ltescrsrq", "ltencrsrp1", "ltencrsrq1", "ltencearfcn1", "ltencpci1", "ltencrsrp2", "ltencrsrq2", "ltencearfcn2", "ltencpci2", "ltencrsrp3", "ltencrsrq3", "ltencearfcn3", "ltencpci3").map(data.columns.indexOf(_))
    println(featInd)
    //    Label column (the grid index llgridid)
    val Label = data.columns.indexOf("llgridid")
    val datause = data.map { x =>
      val label = x(Label).toString.toInt
      val feature = Vectors.dense(featInd.map(x.getDouble(_)).toArray)
      //                   println(feature)
      LabeledPoint(label, feature)
    }
    println("run here2 !!!!!!!!")
    //Split the data into a training set and a test set
    val splits = datause.randomSplit(Array(tarining_rate, test_rate))
    val (trainingData, testData) = (splits(0), splits(1))

    //Number of classes (taken from config; could also be inferred from the maximum label, see the commented line below)
//    val numClasses = (datause.map { x => x.label }.max() + 1).toInt
    val numClasses = class_num
    //Count the training and test samples
    val train_sample = trainingData.count()
    val test_sample = testData.count()
    println("run here3 !!!!!!!!")

    //If the output directory from a previous run exists, delete it, then recreate the model output directory.
    val path = new Path(model_out_path)
    val hdfs = FileSystem.get(new java.net.URI(model_out_path), new org.apache.hadoop.conf.Configuration())
    if (hdfs.exists(path)) hdfs.delete(path, true)
    hdfs.mkdirs(path)

    var bestscore=0.0
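    // Grid search over the number of trees and the maximum depth; the model with the best test accuracy is saved as model-RF-best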
    for (numTrees <- treeList; maxDepth <- depthList) {
      val s = Strategy.defaultStrategy("Classification")
      s.setMaxMemoryInMB(2048)
      s.setNumClasses(numClasses)
      s.setMaxDepth(maxDepth)
      s.setMaxBins(maxBins)

      val model = RandomForest.trainClassifier(trainingData, s, numTrees, featureSubsetStrategy, 10) // the last argument is the random seed
      // Evaluate the trained classifier on the test data and compute the error rate
      val labelAndPreds = testData.map { point =>
        val prediction = model.predict(point.features)
        (point.label, prediction)
      }

      val quota = Quota.calculate(labelAndPreds, testData)
      val testErr = quota._1
      //      val testRecall = quota._3
      //      val f1_score = quota._4

      println("Test Error = " + testErr)
//            println("Learned classification forest model:\n" + model.toDebugString)

//      hdfs.createNewFile(new Path(describe + s"result-$numTrees-$maxDepth-$testErr.txt")) 

//      val dirfile = new File(describe);
//      if (!dirfile.isDirectory()) {
//        dirfile.mkdirs()
//      }
//      val resultfile = new File(describe + s"result-$numTrees-$maxDepth-$testErr.txt")
//      if(!resultfile.isFile()){
//        val writer = new PrintWriter(resultfile)
//        //      writer.println("train pos count:" + pos_sample + "\n")
//        //      writer.println("train neg count:" + neg_sample + "\n")
//        writer.println("train count:" + train_sample + "\n")
//        writer.println("test count:" + test_sample + "\n")
//        writer.println("Test Error = " + testErr + "\n")
//        writer.println(model.toDebugString)
//        writer.close()
//      }

      println(s"model-$numTrees-$maxDepth:"+(1-testErr))
      println(model.toDebugString)

      // Persist the trained random forest model, tagged with its parameters, error rate and timestamp
      val now: Date = new Date()
      val dateFormat: SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd-HH-mm-ss")
      val date = dateFormat.format(now)
      val path = new Path(model_out_path + s"model-$numTrees-$maxDepth-$testErr-$date");    
      // Save the model only if this output path does not already exist
      val hdfs = org.apache.hadoop.fs.FileSystem.get(    
            new java.net.URI(model_out_path + s"model-$numTrees-$maxDepth-$testErr-$date"), new org.apache.hadoop.conf.Configuration()) 
      if (!hdfs.exists(path)){
        model.save(sc, model_out_path + s"model-$numTrees-$maxDepth-$testErr-$date")
      }    

      if(1-testErr>=bestscore){

        //If this model matches or beats the best score so far, replace the previously saved best model
        //(model_file is expected to point at model_out_path + "model-RF-best")
        val path = new Path(model_file);    
        val hdfs = org.apache.hadoop.fs.FileSystem.get(    
              new java.net.URI(model_file), new org.apache.hadoop.conf.Configuration())      
        if (hdfs.exists(path)) hdfs.delete(path, true)    
        model.save(sc, model_out_path + "model-RF-best")
        bestscore = 1-testErr
      }

    }
    sc.stop()
    println("best score:"+bestscore)
    println("run done !!!!!!!!")
  }
}
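The training class extends RandomForestConf, which is not included in this post; from the code above it has to supply the SparkContext/HiveContext, the paraproperties bundle and the training hyperparameters. Below is a minimal sketch of what such a trait might look like, purely as an assumption: the property file name and all values are illustrative, not the project's actual settings.

package com.inspur.mr.InspurMr.conf

import java.util.Properties
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.sql.hive.HiveContext

// Assumed shape of the configuration trait used by RandomWithHive; the real implementation may differ.
trait RandomForestConf {
  lazy val conf = new SparkConf().setAppName("RandomForestTraining")
  lazy val sc = new SparkContext(conf)
  lazy val hc = new HiveContext(sc)

  // Area/grid parameters loaded from a properties file on the classpath (file name is a placeholder)
  val paraproperties = new Properties()
  paraproperties.load(getClass.getResourceAsStream("/randomforest.properties"))

  // Hyperparameters referenced in the training loop (values are examples only)
  val treeList = List(20, 50)              // candidate numbers of trees
  val depthList = List(10, 15)             // candidate maximum tree depths
  val maxBins = 100                        // maximum number of bins per feature
  val featureSubsetStrategy = "auto"       // feature subset strategy per split
  val class_num = 10000                    // number of grid classes
  val tarining_rate = 0.8                  // training split fraction (spelling as used in the main class)
  val test_rate = 0.2                      // test split fraction
  val model_out_path = "hdfs:///user/inspur/rf_models/"   // model output directory (example path)
  val model_file = model_out_path + "model-RF-best"       // path of the best model
}

Quota.calculate is likewise a project utility that is not shown here; judging from its usage it returns a tuple whose first element is the test error rate, with recall and F1 score in later positions.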

  2. Random forest prediction code

package com.inspur.mr.InspurMr.Classification

import com.inspur.mr.InspurMr.conf.AppConf
import org.apache.spark.mllib.tree.model.RandomForestModel
import com.inspur.mr.InspurMr.Util.MLUtils
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}

object RandomPredict extends AppConf {
  case class TableMrPre(msisdn:String,imsi:String,imei:String,begintime:String,tac:String,eci:String,nettype0:String,long_uri:Double,lat_uri:Double)
  def main(args: Array[String]): Unit = {
    val database = paraproperties.getProperty("database")
    val null_fill = paraproperties.getProperty("null_fill")
    val eare_lon_left = paraproperties.getProperty("eare_lon_left").toDouble
    val eare_lat_left = paraproperties.getProperty("eare_lat_left").toDouble
    val eare_lon_right = paraproperties.getProperty("eare_lon_right")
    val eare_lat_right = paraproperties.getProperty("eare_lat_right")
    val grid_length = paraproperties.getProperty("grid_length")
    val grid_num = paraproperties.getProperty("grid_num").toDouble
    val disgrid = grid_length.toDouble*0.000009
    val disgridhalf = grid_length.toDouble*0.000009/2
    var HOUR_ID = args(0)
    var MONTH_ID = HOUR_ID.substring(0,6)
    var DAY_ID = HOUR_ID.substring(0,8)
    val write_partition = "month_id="+MONTH_ID+","+"day_id="+DAY_ID+","+"hour_id="+HOUR_ID
    val read_partition = "month_id="+MONTH_ID+" and "+"day_id="+DAY_ID+" and "+"hour_id="+HOUR_ID

    conf.setAppName("family_test")
    val pModlePath = postgprop.getProperty("model_file")

    hc.sql(s"use $database")
    // Read the hour partition to be predicted; the first 27 columns match the feature order used at training time
    val predSql = s"""select cellid*1.0 as cellid,ltesctadv*1.0 as ltesctadv,ltescaoa*1.0 as ltescaoa,ltescphr*1.0 as ltescphr,ltescrip*1.0 as ltescrip,ltescsinrul*1.0 as ltescsinrul,ltescearfcn*1.0 as ltescearfcn,ltescpci*1.0 as ltescpci,LON0*1.0 as LON0,LAT0*1.0 as LAT0,azimuth0*1.0 as azimuth0,coverflag0*1.0 as coverflag0,nettype0*1.0 as nettype0,ltescrsrp*1.0 as ltescrsrp,ltescrsrq*1.0 as ltescrsrq,ltencrsrp1*1.0 as ltencrsrp1,ltencrsrq1*1.0 as ltencrsrq1,ltencearfcn1*1.0 as ltencearfcn1,ltencpci1*1.0 as ltencpci1,ltencrsrp2*1.0 as ltencrsrp2,ltencrsrq2*1.0 as ltencrsrq2,ltencearfcn2*1.0 as ltencearfcn2,ltencpci2*1.0 as ltencpci2,ltencrsrp3*1.0 as ltencrsrp3,ltencrsrq3*1.0 as ltencrsrq3,ltencearfcn3*1.0 as ltencearfcn3,ltencpci3*1.0 as ltencpci3,msisdn,imsi,imei,begintime,tac,eci,nettype0 from dw_pods_mro_eutrancell_pre_yyyymmdd where $read_partition""".stripMargin
    println(predSql)
    val data = hc.sql(predSql)
    println("run here1 !!!!!!!!")
    //data.show()   //
    val sameModel = RandomForestModel.load(sc, pModlePath)
    println("run here2!!!!!")
    val labelAndPreds = data.map { row =>
      def isNull(xarr:Any):String = if (null==xarr) "-2" else xarr.toString() 
      val rowStr = isNull(row(0))+" "+isNull(row(1))+" "+isNull(row(2))+" "+isNull(row(3))+" "+isNull(row(4))+" "+isNull(row(5))+" "+isNull(row(6))+" "+isNull(row(7))+" "+isNull(row(8))+" "+isNull(row(9))+" "+isNull(row(10))+" "+isNull(row(11))+" "+isNull(row(12))+" "+isNull(row(13))+" "+isNull(row(14))+" "+isNull(row(15))+" "+isNull(row(16))+" "+isNull(row(17))+" "+isNull(row(18))+" "+isNull(row(19))+" "+isNull(row(20))+" "+isNull(row(21))+" "+isNull(row(22))+" "+isNull(row(23))+" "+isNull(row(24))+" "+isNull(row(25))+" "+isNull(row(26))    
      val prediction = sameModel.predict(Vectors.dense(rowStr.split(' ').map { _.toDouble }))
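      // Convert the predicted grid id back to the longitude/latitude of the grid-cell centre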
      val glong = prediction % grid_num
      val glat = math.floor(prediction / grid_num)
      val lonPre=glong*disgrid+eare_lon_left+disgridhalf
      val latPre=eare_lat_left-glat*disgrid-disgridhalf
      TableMrPre(isNull(row(27)),isNull(row(28)),isNull(row(29)),isNull(row(30)),isNull(row(31)),isNull(row(32)),isNull(row(33)),lonPre,latPre)
    }.cache
    println("run here4!!!!!")

    import hc.implicits._ 
    val tabledf = labelAndPreds.toDF()
//    tabledf.show(100)
    tabledf.registerTempTable("TempTableMrPre")
    hc.sql("insert OVERWRITE table dw_mr_mme_position_pre partition("+write_partition+") select * from TempTableMrPre")  
    hc.dropTempTable("TempTableMrPre")
    sc.stop()
    println("run done!!!!!")

  }

}
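The prediction class extends AppConf, which is also not shown in the post. From its usage it needs to provide the SparkConf, SparkContext and HiveContext plus the paraproperties and postgprop property bundles (postgprop holding the model_file path). A minimal sketch under those assumptions, with placeholder file names:

package com.inspur.mr.InspurMr.conf

import java.util.Properties
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.sql.hive.HiveContext

// Assumed shape of the base configuration trait used by RandomPredict;
// the property file names below are placeholders.
trait AppConf {
  // lazy so that callers can still tweak the SparkConf (e.g. setAppName) before the context is created
  lazy val conf = new SparkConf()
  lazy val sc = new SparkContext(conf)
  lazy val hc = new HiveContext(sc)

  // Area/grid parameters shared with the training job
  val paraproperties = new Properties()
  paraproperties.load(getClass.getResourceAsStream("/randomforest.properties"))

  // Output-side settings, including the saved best-model path under the key "model_file"
  val postgprop = new Properties()
  postgprop.load(getClass.getResourceAsStream("/output.properties"))
}

With such a setup the job takes the hour partition as its only argument, e.g. spark-submit --class com.inspur.mr.InspurMr.Classification.RandomPredict inspur-mr.jar 2017080110 (jar name and hour value are illustrative).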
