Spark_Handling groupBy Data Skew in SparkSQL / DataFrames

 

Data skew is a problem you may well run into. For groupBy data skew in Hive, there are already parameters that handle it well. Hive reference article:

https://blog.csdn.net/u010003835/article/details/105495135

Below we look at how SparkSQL can solve this kind of groupBy data skew.

 

The idea:

  It is essentially the same trick as the two Hive tuning parameters, which split the job into two stages:

  1. set hive.map.aggr=true;

  2. set hive.groupby.skewindata=true;

  An extra job first aggregates on a random grouping key; the intermediate result is then aggregated again by the intended key.

 

The approach

  1. Generate a random salt column (an integer in 0~9).

  2. Aggregate by the salt column together with the key you ultimately want to aggregate by.

  3. Aggregate that intermediate result again by the intended key only (a minimal sketch of these three steps follows below).
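
Before the full implementations, here is a minimal standalone sketch of these three steps. It assumes a hypothetical DataFrame df with a grouping column "key" and a numeric column "value" (illustrative names only; the real code below uses uid / name / money):

import org.apache.spark.sql.functions._

// Step 1: add a random 0~9 salt column.
val salted = df.withColumn("salt", (rand() * 10).cast("int"))

// Step 2: partial aggregation by (salt, key) spreads a hot key over up to 10 groups.
val partial = salted
  .groupBy("salt", "key")
  .agg(sum("value").as("partial_sum"))

// Step 3: final aggregation by the original key combines the partial sums.
val result = partial
  .groupBy("key")
  .agg(sum("partial_sum").as("total_value"))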

 

Implementation

There are two ways to implement this:

1. DataFrame functions && UDF

2. SQL + SQL functions

Full code:

package com.spark.test.offline.skewed_data

import java.util.Random

import org.apache.spark.SparkConf
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

/**
  * Created by szh on 2020/5/29.
  */
object GroupBySkewedData {


  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf
    sparkConf
      .setAppName("Union data test")
      .setMaster("local[1]")

    val spark = SparkSession.builder()
      .config(sparkConf)
      .getOrCreate()

    val sparkContext = spark.sparkContext
    sparkContext.setLogLevel("WARN")


    val arrayA = Array(
      (1, "mm", BigDecimal.valueOf(33.2))
      , (2, "cs", BigDecimal.valueOf(22.1))
      , (3, "cc", BigDecimal.valueOf(22.2))
      , (4, "px", BigDecimal.valueOf(22))
      , (5, "kk", BigDecimal.valueOf(22))
      , (2, "cs", BigDecimal.valueOf(22)))


    val rddA = sparkContext
      .parallelize(arrayA)
      .map(x => Row(x._1, x._2, x._3))
    // .parallelize(arrayA, 4)
    // The second argument to parallelize sets the number of partitions (the parallelism).

    println("rddA partition num :" + rddA.partitions.length)

    val rddAStruct = StructType(
      Array(
        StructField("uid", IntegerType, nullable = true)
        , StructField("name", StringType, nullable = true)
        , StructField("money", DecimalType.SYSTEM_DEFAULT, nullable = true)
      )
    )

    val rddADF = spark.createDataFrame(rddA, rddAStruct)
    rddADF.createOrReplaceTempView("tmpA")


    // Define a UDF that returns a random salt in 0~9.
    // The Int argument is ignored; it only exists so the UDF can be applied to a column.
    val rand = (arg: Int) => {
      val random = new Random()
      random.nextInt(10)
    }

    val randUdf = udf(rand)


    // Method 1: DataFrame functions && UDF
    // Stage 1: aggregate by (salt, uid, name) so a hot key is spread across up to 10 groups.
    val midDF = rddADF.withColumn("salt", randUdf(rddADF("uid")))
      .groupBy("salt", "uid", "name")
      .agg(Map("money" -> "sum"))

    // Stage 2: re-aggregate the partial sums by the original key (uid, name).
    val resultDF = midDF
      .groupBy("uid", "name")
      .sum("sum(money)")
      .toDF("uid", "name", "total_money")


    println("resultDF's rdd partition num :" + resultDF.rdd.partitions.length)
    resultDF.explain()
    resultDF.show()

    System.out.println("  ")
    System.out.println(" -----------------------------  ")
    System.out.println(" -----------------------------  ")
    System.out.println(" -----------------------------  ")
    System.out.println(" -----------------------------  ")
    System.out.println(" -----------------------------  ")
    System.out.println("  ")

    // Method 2: SQL + SQL functions
    // Stage 1 input: attach a random 0~9 salt column using rand().
    spark
      .sql("SELECT uid, name, money, cast(rand() * 10 as int) as salt  " +
        "FROM tmpA ")
      .createOrReplaceTempView("midResult")

    val resultDF2 = spark.sql("" +
      "SELECT uid, name, SUM(mid_money) AS total_money " +
      "FROM ( " +
      " SELECT uid, name, salt, SUM(money) AS mid_money " +
      " FROM midResult " +
      " GROUP BY uid, name, salt " +
      " ) AS tmp " +
      "GROUP BY uid, name "
    )

    resultDF2.explain()
    resultDF2.show()

    // Sleep for 10 minutes, e.g. to keep the Spark UI available for inspection.
    Thread.sleep(60 * 10 * 1000)

    sparkContext.stop()
  }
}

 

Output

rddA partition num :1
resultDF's rdd partition num :200
== Physical Plan ==
*HashAggregate(keys=[uid#3, name#4], functions=[sum(sum(money)#22)])
+- Exchange hashpartitioning(uid#3, name#4, 200)
   +- *HashAggregate(keys=[uid#3, name#4], functions=[partial_sum(sum(money)#22)])
      +- *HashAggregate(keys=[salt#11, uid#3, name#4], functions=[sum(money#5)])
         +- Exchange hashpartitioning(salt#11, uid#3, name#4, 200)
            +- *HashAggregate(keys=[salt#11, uid#3, name#4], functions=[partial_sum(money#5)])
               +- *Project [uid#3, name#4, money#5, if (isnull(uid#3)) null else UDF(uid#3) AS salt#11]
                  +- Scan ExistingRDD[uid#3,name#4,money#5]
+---+----+--------------------+
|uid|name|         total_money|
+---+----+--------------------+
|  3|  cc|22.20000000000000...|
|  4|  px|22.00000000000000...|
|  1|  mm|33.20000000000000...|
|  2|  cs|44.10000000000000...|
|  5|  kk|22.00000000000000...|
+---+----+--------------------+

  
 -----------------------------  
 -----------------------------  
 -----------------------------  
 -----------------------------  
 -----------------------------  
  
== Physical Plan ==
*HashAggregate(keys=[uid#3, name#4], functions=[sum(mid_money#65)])
+- Exchange hashpartitioning(uid#3, name#4, 200)
   +- *HashAggregate(keys=[uid#3, name#4], functions=[partial_sum(mid_money#65)])
      +- *HashAggregate(keys=[uid#3, name#4, salt#58], functions=[sum(money#5)])
         +- Exchange hashpartitioning(uid#3, name#4, salt#58, 200)
            +- *HashAggregate(keys=[uid#3, name#4, salt#58], functions=[partial_sum(money#5)])
               +- *Project [uid#3, name#4, money#5, cast((rand(-7591047829286253872) * 10.0) as int) AS salt#58]
                  +- Scan ExistingRDD[uid#3,name#4,money#5]
+---+----+--------------------+
|uid|name|         total_money|
+---+----+--------------------+
|  3|  cc|22.20000000000000...|
|  4|  px|22.00000000000000...|
|  1|  mm|33.20000000000000...|
|  2|  cs|44.10000000000000...|
|  5|  kk|22.00000000000000...|
+---+----+--------------------+
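
A note not covered in the original code: the two-stage trick works here because SUM decomposes cleanly, i.e. the sum of the per-salt partial sums equals the overall sum. COUNT can be handled the same way (sum the partial counts), but AVG has to be rebuilt from partial sums and counts, and COUNT(DISTINCT) cannot be fixed by simply summing per-salt distinct counts. A rough sketch of the AVG case, reusing the hypothetical df / key / value names from the sketch above:

import org.apache.spark.sql.functions._

// Salted AVG rebuilt from partial sums and counts (illustrative sketch, not from the original post).
val partial = df
  .withColumn("salt", (rand() * 10).cast("int"))
  .groupBy("salt", "key")
  .agg(sum("value").as("partial_sum"), count("value").as("partial_cnt"))

val avgResult = partial
  .groupBy("key")
  .agg((sum("partial_sum") / sum("partial_cnt")).as("avg_value"))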

 

 

 
