Spark utility class (commonly used)

package com.fengtu.sparktest.utils

import java.sql.DriverManager
import java.util.{Map, Properties}

import com.alibaba.fastjson.{JSONArray, JSONObject}
import com.fengtu.sparktest.utils2.Utils
import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.StructType
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object SparkUtils {
  // seqOp/combOp functions (the curried arguments of aggregateByKey): aggregate JSONObject values into a List[JSONObject]
  val seqOp = (a: List[JSONObject], b: JSONObject) => a.size match {
    case 0 => List(b)
    case _ => b::a
  }

  val combOp = (a: List[JSONObject], b: List[JSONObject]) => {
    a ::: b
  }

  val seqOpRow = (a: List[Row], b: Row) => a.size match {
    case 0 => List(b)
    case _ => b::a
  }

  val combOpRow = (a: List[Row], b: List[Row]) => {
    a ::: b
  }
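
  // Usage sketch (assumption: `pairs` is an RDD[(String, JSONObject)] built by the caller):
  //   val grouped: RDD[(String, List[JSONObject])] =
  //     pairs.aggregateByKey(List.empty[JSONObject])(seqOp, combOp)
  // The Row variants work the same way on an RDD[(String, Row)] with seqOpRow/combOpRow.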

  // Send a POST request, retrying once on failure
  def doPost(url:String,reqJson:JSONObject,logger:Logger): String = {
    var responseBody = "{}"

    try {
      responseBody = Utils.post(url, reqJson, "utf-8")
    } catch {
      case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n$reqJson")
        try {
          responseBody = Utils.post(url, reqJson, "utf-8")
        } catch {
          case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n$reqJson")
            responseBody = "POST request failed: " + e.toString
        }
    }
    responseBody
  }

  // Send a POST form request, retrying once on failure
  def doPostForm(url: String, map: Map[String, String], logger: Logger):String = {
    var responseBody = "{}"

    try {
      responseBody = Utils.postForm(url, map, "utf-8")
    } catch {
      case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n${map.toString}")
        try {
          responseBody = Utils.postForm(url, map, "utf-8")
        } catch {
          case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n${map.toString}")
            responseBody = "POST request failed: " + e.toString
        }
    }

    responseBody
  }


  /**
    * Send a POST request with a JSON array body, retrying once on failure
    * @param url        target URL
    * @param reqJsonArr JSON array request body
    * @param logger     logger used to record failures
    * @return response body, or an error message if both attempts fail
    */
  def doPostArr(url:String, reqJsonArr:JSONArray, logger:Logger): String = {
    var responseBody = "{}"

    try {
      responseBody = Utils.post(url, reqJsonArr, "utf-8")
    } catch {
      case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n$reqJsonArr")
        try {
          responseBody = Utils.post(url, reqJsonArr, "utf-8")
        } catch {
          case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n$reqJsonArr")
            responseBody = "POST request failed: " + e.toString
        }
    }
    responseBody
  }

  def df2Hive(spark:SparkSession,rdd: RDD[Row],schema:StructType,saveMode:String,descTable:String,partitionSchm:String,incDay:String,logger: Logger): Unit = {
    logger.error(s"Writing to Hive table ${descTable}...")
    val df = spark.sqlContext.createDataFrame(rdd,schema)
    // Drop the target partition before writing
    val dropSql = s"alter table $descTable drop if exists partition($partitionSchm='$incDay')"
    logger.error(dropSql)
    spark.sql(dropSql)

    df.write.format("hive").mode(saveMode).partitionBy(partitionSchm).saveAsTable(descTable)

    logger.error(s"Partition ${incDay} written successfully")
  }

  def df2HivePs(spark:SparkSession,rdd: RDD[Row],schema:StructType,saveMode:String,descTable:String,incDay:String,region:String,logger: Logger,partitionDay:String,partitionRegion:String): Unit = {
    logger.error(s"Writing to Hive table ${descTable}...")
    val df = spark.sqlContext.createDataFrame(rdd,schema)
    // Drop the target partition before writing
    val dropSql = s"alter table $descTable drop if exists partition($incDay='$partitionDay',$region='$partitionRegion')"
    logger.error(dropSql)
    spark.sql(dropSql)

    df.write.format("hive").mode(saveMode).partitionBy(incDay,region).saveAsTable(descTable)

    logger.error(s"Partitions $partitionDay,$partitionRegion written successfully")
  }

  def df2Hive(spark:SparkSession, df:DataFrame,saveMode:String,descTable:String,partitionSchm:String,incDay:String,logger: Logger): Unit = {
    logger.error(s"Writing to Hive table ${descTable}...")

    // Drop the target partition before writing
    val dropSql = s"alter table ${descTable} drop if exists partition($partitionSchm='$incDay')"
    logger.error(dropSql)
    spark.sql(dropSql)

    df.write.format("hive").mode(saveMode).partitionBy(partitionSchm).saveAsTable(descTable)

    logger.error(s"Partition ${incDay} written successfully")
  }


  // Convert each row of the DataFrame to a JSONObject (assumes all columns are string-typed)
  def getRowToJson(sourDf:DataFrame,parNum:Int=200) = {

    val colList = sourDf.columns

    val sourRdd = sourDf.rdd.repartition(parNum).map( obj => {
      val jsonObj = new JSONObject()
      for (column <- colList) {
        jsonObj.put(column,obj.getAs[String](column))
      }
      jsonObj
    }).persist(StorageLevel.DISK_ONLY)

    //println(s"Rows fetched: ${sourRdd.count()}")

    // Release the source DataFrame if the caller persisted it
    sourDf.unpersist()

    sourRdd
  }


  /**
    * Write an RDD to MySQL, deleting that day's rows first
    *
    * @param spark    SparkSession
    * @param rdd      data to write
    * @param schema   schema of the rows
    * @param user     JDBC user
    * @param password JDBC password
    * @param saveMode unused; data is always appended after the delete
    * @param jdbcUrl  JDBC connection URL
    * @param tblName  target table
    * @param incDay   date value whose rows are reloaded
    * @param logger   logger
    * @param statdate name of the date column, defaults to "statdate"
    */
  def df2Mysql(spark: SparkSession,rdd: RDD[Row],schema:StructType,user:String,password:String,
               saveMode:String,jdbcUrl:String,tblName:String,incDay: String, logger: Logger,statdate:String = "statdate"): Unit = {

    val delSql = String.format(s"delete from $tblName where %s='%s'",statdate,incDay)
    logger.error(">>>Deleting the current day's rows before saving: " + delSql)
    Class.forName("com.mysql.jdbc.Driver")
    val conn = DriverManager.getConnection(jdbcUrl,user,password)
    DbUtils.executeSql(conn, delSql)
    //conn.close()

    // Build a DataFrame from the RDD
    val tmpTbl = spark.sqlContext.createDataFrame(rdd,schema).persist()

    // JDBC connection properties
    val prop = new Properties()
    prop.setProperty("user", user)
    prop.setProperty("password", password)

    logger.error(s"Writing to MySQL table ${tblName}")

    // Append the data to the target table
    tmpTbl.write.mode(SaveMode.Append).jdbc(jdbcUrl,tblName,prop)

    DbUtils.querySql(conn,String.format(s"select count(1) from $tblName where %s='%s'",statdate,incDay))
    conn.close()
    tmpTbl.unpersist()
  }


  /**
    * Aggregate after randomly salting the keys (mitigates data skew)
    *
    * @param obj     input data
    * @param hashNum salt range; a random value in [0, hashNum) is prepended to the key
    * @return the values of each original key, merged back together
    */
  def groupByKeyTwoStep(obj: RDD[(String, Object)], hashNum: Int): RDD[(String, ArrayBuffer[Object])] = {
    // Step 1: add a random salt prefix to the key and aggregate on the salted key
    val hashData = obj.map(record => {
      val hashPrefix = new Random().nextInt(hashNum)
      ((hashPrefix, record._1), record._2)
    }).groupByKey().map(salted => {
      (salted._1._2, salted._2.toArray)
    })
    // Step 2: drop the salt and aggregate again on the original key
    hashData.groupByKey().map(grouped => {
      val key = grouped._1
      val valueIterator = grouped._2.iterator
      val ret = new ArrayBuffer[Object]
      while (valueIterator.hasNext) {
        val tmpArray = valueIterator.next()
        ret.appendAll(tmpArray)
      }
      (key, ret)
    })
  }
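
  // Usage sketch (assumption: `skewed` is an RDD[(String, Object)] with a few hot keys):
  //   val merged: RDD[(String, ArrayBuffer[Object])] = groupByKeyTwoStep(skewed, 10)
  // Salting with 10 prefixes spreads each hot key over up to 10 tasks in the first shuffle,
  // so no single task has to buffer all values of a hot key at once.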

//  def getRowToJson(spark:SparkSession,querySql:String,parNum:Int=200 ) ={
//    val sourDf = spark.sql(querySql).persist(StorageLevel.DISK_ONLY)
//
//    val colList = sourDf.columns
//
//    val sourRdd = sourDf.rdd.repartition(parNum).map( obj => {
//      val jsonObj = new JSONObject()
//      for (columns <- colList) {
//        jsonObj.put(columns,obj.getAs[String](columns))
//      }
//      jsonObj
//    }).persist(StorageLevel.DISK_ONLY)
//
//    println(s"Rows fetched: ${sourRdd.count()}")
//
//    sourDf.unpersist()
//
//    sourRdd
//  }

//
//  def getRowToJsonTurp( spark:SparkSession,querySql:String,key:String,parNum:Int=200 ) ={
//    val sourDf = spark.sql(querySql).persist(StorageLevel.DISK_ONLY)
//
//    val colList = sourDf.columns
//
//    val sourRdd = sourDf.rdd.repartition(parNum).map( obj => {
//      val jsonObj = new JSONObject()
//      for (columns <- colList) {
//        jsonObj.put(columns,obj.getAs[String](columns))
//      }
//      (jsonObj.getString(key),jsonObj)
//    }).persist(StorageLevel.DISK_ONLY)
//
//    sourRdd
//  }
}
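
A minimal end-to-end usage sketch follows. The database, table names, partition value and service URL are placeholders, and Utils/DbUtils are assumed to be the project's own helpers already on the classpath:

package com.fengtu.sparktest.utils

import org.apache.log4j.Logger
import org.apache.spark.sql.SparkSession

object SparkUtilsDemo {
  def main(args: Array[String]): Unit = {
    val logger = Logger.getLogger(getClass)
    val spark = SparkSession.builder()
      .appName("spark-utils-demo")
      .enableHiveSupport()
      .getOrCreate()

    // Placeholder source table and partition value
    val incDay = "20240101"
    val sourceDf = spark.sql(s"select * from demo_db.demo_src where inc_day = '$incDay'")

    // Row -> JSONObject, then one POST per record against a placeholder service.
    // The logger is recreated inside each partition so it never has to be serialized.
    val jsonRdd = SparkUtils.getRowToJson(sourceDf, parNum = 100)
    val responses = jsonRdd.mapPartitions { it =>
      val partLogger = Logger.getLogger("SparkUtilsDemo")
      it.map(json => SparkUtils.doPost("http://example.com/api/score", json, partLogger))
    }
    logger.error(s"responses collected: ${responses.count()}")

    // Reload the inc_day partition of a placeholder target table
    SparkUtils.df2Hive(spark, sourceDf, "append", "demo_db.demo_dst", "inc_day", incDay, logger)

    spark.stop()
  }
}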
