Spark代码可读性与性能优化——示例八(一个业务逻辑,多种解决方式)

文章目录

  • Spark代码可读性与性能优化——示例八(一个业务逻辑,多种解决方式)
    • 1. 前情提要
    • 2. 需求展示
    • 3. 问题分析
      • 3.1 问题一(SQL性能较低)
      • 3.2 问题二(数据倾斜)
      • 3.3 问题三(数据倾斜内的数据倾斜)
    • 4. 多种解决方式的示例
      • 4.1 利用将随机数添加到key上的方式,来解决数据倾斜的问题
      • 4.2 使用reduceByKey,修改key数据结构,再更改后续处理方式
      • 4.3 不修改key数据结构,编写聚合器
      • 4.4 不修改key数据结构,value存Map<字段值,字段值的数量>
      • 4.5 不修改key数据结构,value存(Set<字段值>, 1)
      • 4.6 其他方案,使用 treeAggregate

Spark代码可读性与性能优化——示例八(一个业务逻辑,多种解决方式)

1. 前情提要

  • 在示例七的末尾中提出了一个需求“同时统计某个表所有字段对应的值的总数、去重后的总数,并要求对应字段值非空”。如果你看过示例七,显然应该知道怎么解决。
  • 写这篇文章的目的如下:
    • 再详细描述业务需求,以免误解
    • 提供对业务问题点的分析
    • 展示多种解决方案的示例

2. 需求展示

  • 现有一张表 tb_express,示例如下:
name address trade fix express
王五 四川成都市 d72_network ty002 zto
null 重庆渝北 z03_locker bk213 sf-express
李雷 湖南长沙 null null sf-express
…… …… …… …… ……
李四 广东广州 t92_locker tu87 sto
  • 表信息描述
    • 全表一共50个字段,这里只展示了5个(后面为了方便展示,也只会以这5个作为示例)
    • 全表共计10亿条数据
    • 由于数据源问题,很多字段会存在null值情况。(据最后统计知:fix字段共计9.5亿非空值,其他字段非空值总数在2-3亿范围)
  • 业务需求:
    • 需要统计出所有字段值的总数字段值去重后的总数,并要求字段非空

3. 问题分析

3.1 问题一(SQL性能较低)

  • 其实业务需求本身挺简单的,首先可能第一个想到的就是用SQL进行处理,示例如下:
    SELECT count(name), count(distinct name)
    FROM tb_express
    WHERE name IS NOT NULL AND name != '';
    
  • 但是,因为每个字段的null都不一样,所以SQL没法一次统计完所有字段(除非编写UDAF,但是这样的话你还是要写代码),50个字段都得跑一次。可以看到这样做的缺点:集群资源消耗大,花费时间长。

3.2 问题二(数据倾斜)

  • 所以,你可能会尝试编写代码一次解决问题。一般的编写示例如下:
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    
    /**
      * Description: '字段值总数'与'字段值去重后的总数'的统计(错误示例)
      * 
    * Date: 2019/12/2 16:57 * * @author ALion */ object CountDemo { val expressSchema: StructType = StructType(Array( StructField("name", StringType), StructField("address", StringType), StructField("trade", StringType), StructField("fix", StringType), StructField("express", StringType) )) def main(args: Array[String]): Unit = { val conf = new SparkConf() .setAppName("CountDemo") val spark = SparkSession.builder() .config(conf) .getOrCreate() val expressDF = spark.read.schema(expressSchema).table("tb_express") val resultRDD = expressDF.rdd // .flatMap { row => // val name = row.get(row.fieldIndex("name")) // val address = row.get(row.fieldIndex("address")) // val trade = row.get(row.fieldIndex("trade")) // val fix = row.get(row.fieldIndex("fix")) // val express = row.get(row.fieldIndex("express")) // // val buffer = ArrayBuffer[(String, String)]() // // 去除null值 // // 字段名设置为key,字段值设置为value // if (name != null) buffer.append(("name", name.toString)) // if (address != null) buffer.append(("address", address.toString)) // if (trade != null) buffer.append(("trade", trade.toString)) // if (fix != null) buffer.append(("fix", fix.toString)) // if (express != null) buffer.append(("express", express.toString)) // // buffer // } .flatMap { row => // 更函数式的写法,也更简短 Array("name", "address", "trade", "fix", "express") .flatMap { name => Option(row.get(row.fieldIndex(name))) match { case Some(v) => Some((name, v.toString)) case None => None } // 还有更短的写法=.= // Option(row.get(row.fieldIndex(name))).map(v => (name, v.toString)) } }.groupByKey() .mapValues { iter => (iter.size, iter.toSet.size) } // 此处计算'字段值总数'与'字段值去重后的总数' // 拉取数据,打印结果 resultRDD.collect() .foreach { case (fieldName, (count, distinctCount)) => println(s"字段名 = $fieldName, 字段值总数 = $count, 字段值去重后的总数 = $distinctCount") } spark.stop() } }
  • 编写代码,一般首先想到的都是这种方式,对每个字段名进行分组,分别统计字段值就可以了。
  • 但是,此处代码的问题在于groupByKey导致了数据倾斜,因为前面咱们提到过“fix字段共计9.5亿非空值,其他字段非空值总数在2-3亿范围”。那么,如果以字段值为key,进行groupByKey,会导致shuffle到某个节点数据远远大于其他节点。

3.3 问题三(数据倾斜内的数据倾斜)

  • 这个时候你可能又会想了“我的key里除了放字段名,再放个字段值,让value存1,这样不就不会倾斜了?”。主体部分示例如下:
        val resultRDD = expressDF.rdd
          .flatMap { row =>
            // 更函数式的写法,也更简短
            Array("name", "address", "trade", "fix", "express")
              .flatMap { name =>
                Option(row.get(row.fieldIndex(name))) match {
                  case Some(v) => Some(((name, v.toString), 1))
                  case None => None
                }
                // 还有更短的写法=.=
               // Option(row.get(row.fieldIndex(name))).map(v => ((name, v.toString), 1))
              }
          }.groupByKey()
          .map { case ((name, value), iter) => (name, (value, iter.size)) }
          .groupByKey() // 第二次groupByKey虽然还是以字段名为key,但是因为数据量很小,所以会很快处理完
          .map {case (name, iter) =>
            // (字段名, 字段值总数, 字段值去重后的总数)
            (name, iter.map(_._2).sum, iter.size)
          }
    
  • 但是,前面隐藏了一个秘密,那就是“fix字段共计9.5亿非空值,并且有80%都是相同值,同时业务要求不能排除这些值”,所以一旦运行起来,还是会数据倾斜,卡在一个点上。(不过你显然已经想到用reduceByKey来解决了)
  • 另外,你可以用数字编号(Int)来代表字段名(String),以降低内存消耗。不过后续示例为了方便查看,还是使用的字段名。

4. 多种解决方式的示例

  • 看了前面的需求分析,显然我们知道问题主要在于数据倾斜,得想办法解决它。那么一般的解决方式都是reduceByKey,不过我们也可以想其他办法,下面我就编写一些示例,以供参考。

4.1 利用将随机数添加到key上的方式,来解决数据倾斜的问题

val resultRDD = expressDF.rdd
  .flatMap { row =>
    val random = new Random()
    // 更函数式的写法,也更简短
    Array("name", "address", "trade", "fix", "express")
      .flatMap { name =>
        Option(row.get(row.fieldIndex(name))) match {
          // 随机范围取100,最终会导致数据分成100份。根据当前集群启动的节点数合理取值,可以达到更好的效果。
          case Some(v) => Some((name + "_" + random.nextInt(100), v.toString))
          case None => None
        }
        // 还有更短的写法=.=
        // Option(row.get(row.fieldIndex(name))).map(v => (name + "_" + random.nextInt(100), v.toString))
      }
  }.groupByKey()
  .map { case (k, v) =>
    // 完成本次聚合,并去掉随机数
    (k.split("_")(0), (v.size, v.toSet))
  }.groupByKey() // 同样的,你也可以写reduceByKey,不过此处几乎没有效率影响
    .mapValues {iter =>
      // (字段值总数, 字段值去重后的总数)
      (iter.map(_._1).sum, iter.map(_._2).reduce(_ ++ _).size)
    }

4.2 使用reduceByKey,修改key数据结构,再更改后续处理方式

val resultRDD = expressDF.rdd
  .flatMap { row =>
    // 更函数式的写法,也更简短
    Array("name", "address", "trade", "fix", "express")
      .flatMap { name =>
        Option(row.get(row.fieldIndex(name))) match {
          case Some(v) => Some(((name, v.toString), 1))
          case None => None
        }
        // 更短
        // Option(row.get(row.fieldIndex(name))).map(v => ((name, v.toString), 1))
      }
  }.reduceByKey(_ + _) // 将问题分析最后示例中的groupByKey替换为reduceByKey即可解决
  .map { case ((name, value), count) => (name, (value, count)) }
  // 第二次groupByKey虽然还是以字段名为key,但是因为数据量很小,所以会很快处理完。
  // 当然你这里也可以使用reduceByKey。
  .groupByKey()  
  .map { case (name, iter) =>
    // (字段名, 字段值总数, 字段值去重后的总数)
    (name, iter.map(_._2).sum, iter.size)
  }

4.3 不修改key数据结构,编写聚合器

  • 聚合器类 CountAggregator
    class CountAggregator(var count: Int, var countSet: mutable.HashSet[String]) {
      
      def +=(element: (Int, String)) : CountAggregator = {
        this.count += element._1
        this.countSet += element._2
    
        this
      }
      
      def ++=(that: CountAggregator): CountAggregator = {
        this.count += that.count
        this.countSet ++= that.countSet
    
        this
      }
    
    }
    
    object CountAggregator {
    
      def apply(): CountAggregator =
        new CountAggregator(0, mutable.HashSet[String]())
    
    }
    
  • Spark主体代码
    val resultRDD = expressDF.rdd
      .flatMap { row =>
        // 更函数式的写法,也更简短
        Array("name", "address", "trade", "fix", "express")
          .flatMap { name =>
            Option(row.get(row.fieldIndex(name))) match {
              case Some(v) => Some((name, (1, v.toString)))
              case None => None
            }
            // 更短
            // Option(row.get(row.fieldIndex(name))).map(v => (name, (1, v.toString)))
          }
      }.aggregateByKey(CountAggregator())(
        (agg, v) => agg += v,
        (agg1, agg2) => agg1 ++= agg2
      ).mapValues { aggregator =>
        // (字段值总数, 字段值去重后的总数)
        (aggregator.count, aggregator.countSet.size)
      }
    

4.4 不修改key数据结构,value存Map<字段值,字段值的数量>

 val resultRDD = expressDF.rdd
   .flatMap { row =>
     // 更函数式的写法,也更简短
     Array("name", "address", "trade", "fix", "express")
       .flatMap { name =>
         Option(row.get(row.fieldIndex(name))) match {
           case Some(v) => Some((name, (v.toString, 1)))
           case None => None
         }
         // 更短
         // Option(row.get(row.fieldIndex(name))).map(v => (name, (v.toString, 1)))
       }
   }.aggregateByKey(mutable.HashMap[String, Int]())(
     (map, kv) => map += (kv._1 -> (map.getOrElse(kv._1, 0) + kv._2)),
     (map1, map2) => {
       for ((k, v) <- map2) {
         map1 += (k -> (map1.getOrElse(k, 0) + v))
       }
       map1
     }
   ).mapValues { map => 
      // 字段值总数, 字段值去重后的总数
        (map.values.sum, map.keySet.size)
   }

4.5 不修改key数据结构,value存(Set<字段值>, 1)

 val resultRDD = expressDF.rdd
   .flatMap { row =>
     // 更函数式的写法,也更简短
     Array("name", "address", "trade", "fix", "express")
       .flatMap { name =>
         Option(row.get(row.fieldIndex(name))) match {
           case Some(v) => Some((name, (v.toString, 1)))
           case None => None
         }
         // 更短
         // Option(row.get(row.fieldIndex(name))).map(v => (name, (v.toString, 1)))
       }
   }.aggregateByKey((mutable.HashSet[String](), 0))(
     // 会提示错误,但是能通过编译
     (agg, v) => (agg._1 += v._1, agg._2 + v._2),
     (agg1, agg2) => (agg1._1 ++= agg2._1, agg1._2 + agg2._2)
   )
   /*.aggregateByKey((mutable.HashSet[String](), 0))(
     // 解决误提示的方法1
     (agg, v) => {
       agg._1 += v._1
       (agg._1 , agg._2 + v._2)
     },
     (agg1, agg2) => {
       agg1._1 ++= agg2._1
       (agg1._1 , agg1._2 + agg2._2)
     }
   )*/
   /*.aggregateByKey((mutable.HashSet[String](), 0))(
        // 解决错误提示的方法2
        (agg, v) => ((agg._1 += v._1).asInstanceOf[mutable.HashSet[String]] , agg._2 + v._2),
        (agg1, agg2) => ((agg1._1 ++= agg2._1).asInstanceOf[mutable.HashSet[String]], agg1._2 + agg2._2)
      )*/
    /*.aggregateByKey((immutable.HashSet[String](), 0))(
       // 解决错误提示的方法3
       (agg, v) => (agg._1.+(v._1) , agg._2 + v._2),
       (agg1, agg2) => (agg1._1.++(agg2._1), agg1._2 + agg2._2)
     )*/
   .mapValues { case (set, count)  =>
     // (字段值总数, 字段值去重后的总数)
     (count, set.size)
   }

4.6 其他方案,使用 treeAggregate

  • 示例代码如下 (针对当前业务逻辑,不推荐使用treeAggregate)
  • Aggregator
    import scala.collection.mutable
    
    case class Counter(var count: Int, set: mutable.HashSet[String])
    
    class CountAggregator2(var counter1: Counter,
                           var counter2: Counter,
                           var counter3: Counter,
                           var counter4: Counter,
                           var counter5: Counter) {
    
      def +=(element: (Any, Any, Any, Any, Any)): CountAggregator2 = {
        def countFunc(counter: Counter, e: Any): Unit = {
          if (e != null) {
            counter.count += 1
            counter.set += e.toString
          }
        }
    
        countFunc(counter1, element._1)
        countFunc(counter2, element._2)
        countFunc(counter3, element._3)
        countFunc(counter4, element._4)
        countFunc(counter5, element._5)
    
        this
      }
    
      def ++=(that: CountAggregator2): CountAggregator2 = {
        this.counter1.count += that.counter1.count
        this.counter1.set ++= that.counter1.set
        this.counter2.count += that.counter2.count
        this.counter2.set ++= that.counter2.set
        this.counter3.count += that.counter3.count
        this.counter3.set ++= that.counter3.set
        this.counter4.count += that.counter4.count
        this.counter4.set ++= that.counter4.set
        this.counter5.count += that.counter5.count
        this.counter5.set ++= that.counter5.set
    
        this
      }
    
    }
    
    
    object CountAggregator2 {
    
      def apply(): CountAggregator2 =
        new CountAggregator2(
          Counter(0, mutable.HashSet[String]()),
          Counter(0, mutable.HashSet[String]()),
          Counter(0, mutable.HashSet[String]()),
          Counter(0, mutable.HashSet[String]()),
          Counter(0, mutable.HashSet[String]())
        )
    
    }
    
  • spark主体代码
     val result = expressDF.rdd
       .map { row =>
         val rowAny = (name: String) => row.get(row.fieldIndex(name))
         
         (rowAny("name"), rowAny("address"), rowAny("trade"), rowAny("fix"), rowAny("express"))
       }
       // 请根据业务、数据量,来调整treeAggregate的深度depth,默认为2
       .treeAggregate(CountAggregator2())(
         (agg, v) => agg += v,
         (agg1, agg2) => agg1 ++= agg2
       )
    
  • 优点
    • treeAggregate利用了并行计算,先在每个节点进行reduce,最后再合并到一个节点进行reduce
    • 不用像前面的示例一样,为每个字段名生成一个值(一行数据有多少字段,几乎就会生成多少个字段名),用作聚合的key,减少了内存空间占用
  • 缺点
    • 需要尽可能让最后reduce结果的数据量变小,例如是多个基本的值(最值、计数值)、数据量较少的集合(排序前10名)
    • 此处示例,最后每个字段聚合了一个去重的Set,如果Set内数据量较大,内存占用变多,那么可能会导致driver端挂掉

你可能感兴趣的:(Scala,BigData,#,Spark)