Spark代码可读性与性能优化——示例六(groupBy、reduceByKey、aggregateByKey)

文章目录

  • Spark代码可读性与性能优化——示例六(GroupBy、ReduceByKey)
    • 1. 普通常见优化示例
      • 1.1 错误示例 groupByKey
      • 1.2 正确示例 reduceByKey
    • 2. 高级优化
      • 2.0. 需求:统计历年全国高考生中数学成绩前100名
      • 2.1 数据示例
      • 2.2 存在问题的代码示例
      • 2.3 如何解决代码中的问题?
      • 2.4 最终代码,以及其他附件代码

Spark代码可读性与性能优化——示例六(GroupBy、ReduceByKey)

1. 普通常见优化示例

1.1 错误示例 groupByKey

import org.apache.spark.{SparkConf, SparkContext}

object GroupNormal {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("GroupNormal")
    val sc = new SparkContext(conf)

    // 数据可能有几亿条,此处只做模拟示例
    val dataRDD = sc.parallelize(List(
      ("hello", 2),
      ("java", 7),
      ("where", 1),
      ("rust", 2),
      // 中间还有很多数据,不做展示
      ("scala", 1),
      ("java", 1),
      ("black", 9)
    ))

    // 做一个词频统计
    val result = dataRDD.groupByKey()
      .mapValues(_.sum)
      .sortBy(_._2, false)

    result.take(10).foreach(println)

    sc.stop()
  }

}

1.2 正确示例 reduceByKey

    // 修改此部分groupByKey代码为reduceByKey
    val result = dataRDD
      .reduceByKey(_ + _)
      .sortBy(_._2, false)

    result.take(10).foreach(println)

2. 高级优化

2.0. 需求:统计历年全国高考生中数学成绩前100名

2.1 数据示例

id chinese math english year
3412312 121 115 134 2018
5231211 103 131 114 2010
…… …… …… …… ……
2342354 134 105 124 2014
  • 共计约2亿条数据
  • 数据存于Hive中,表名tb_student_score,id值(唯一)代表学生,chinese代表语文,math代表数学,english代表英语

2.2 存在问题的代码示例

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
  * 数据分组错误示例
  *
  * @author ALion
  * @version 2019/5/15 22:33
  */
object GroupDemo {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("GroupDemo")
    val spark = SparkSession.builder()
      .config(conf)
      .enableHiveSupport()
      .getOrCreate()

    // 获取原始数据
    val studentDF = spark.sql(
      """
        |SELECT * 
        |FROM tb_student_score
        |WHERE id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL
      """.stripMargin)
    
    // 开始进行分析
    val resultRDD = studentDF.rdd
      .map(row => {
        val id = row.getLong(row.fieldIndex("id"))
        val math = row.getInt(row.fieldIndex("math"))
        val year = row.getInt(row.fieldIndex("year"))

        (year, (id, math))
      })
      .groupByKey() // 按年分组
      .mapValues(_.toSeq.sortWith(_._2 > _._2).take(100)) // 根据math对每个人进行降序排序,最后获取前100的人
    
    // 触发Action,展示部分统计结果
    resultRDD.take(10).foreach(println)

    spark.stop()
  }

}
  • 首先,可以肯定的是代码逻辑毫无问题,能够满足业务需求。
  • 其次,这部分代码又存在很大的性能问题:
    1. spark.sql("SELECT * FROM tb_student_score")这种形势读取表中数据较慢,有更快的方式
    2. groupByKey处,发生shuffle,大量数据被分到对应的年份的节点中,然后每个节点使用单线程在各年对应的所有数据中对学生进行排序,最后获取前100名
    3. groupByKey处的shuffle可能发生数据倾斜,可能存在部分年份的数据不全或参考人数较少,而部分年份数据较多
  • 另外,直接使用SQL的方案已附在文章末尾

2.3 如何解决代码中的问题?

  • 首先,读取表可以采用DataFrame的API,指定Schema,能够加速表的读取
	val tbSchema = StructType(Array(
	  StructField("id", LongType, true),
	  StructField("chinese", IntegerType, true),
	  StructField("math", IntegerType, true),
	  StructField("english", IntegerType, true),
	  StructField("year", IntegerType, true)
	))
	
	// 获取原始数据
	val studentDF = spark.read.schema(tbSchema).table("tb_student_score")
	      .where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL")
  • 其次,关于groupBy发生shuffle的问题以及排序的问题。似乎数据如果不按年份分组,针对每年所有的分数统一排序,就没有其他办法。因为待排序的数据不在一起好像就不能完整的排序啊?那还怎么谈取前100名啊?
  • 其实不然,想想我们是不是可以先在每个数据分块本地排序一次获取前100名,最后将所有的前100汇总,进行一次总的排序获取总的前100名?这样的话,充分利用了每个分块的并行计算,提前做了部分排序,当数据shuffle的时候每个分块数据就只有100条,最后汇总进行一次排序的数据量就非常小了!其实这就是归并排序的思想,感兴趣的朋友可以搜索‘归并排序’看看。
  • 优化后的示例代码如下:
	// 开始进行分析
    val resultRDD = studentDF.rdd
      .mapPartitions {
        // 自己实现时,如果为了性能更好,不建议这样的函数式写法
        // 这里只是为了方便看
        _.map { row =>
          val id = row.getLong(row.fieldIndex("id"))
          val math = row.getInt(row.fieldIndex("math"))
          val year = row.getInt(row.fieldIndex("year"))

          (year, (id, math))
        }.toArray
          .groupBy(_._1) // 先在每个分块前,获取历年的数学前100名,减少后续groupBy的shuffle数据量
          .mapValues(_.map(_._2).sortWith(_._2 > _._2).take(100))
          .toIterator
      }.groupByKey() // 最后获取所有分块的前100名,再次排序,计算总的前100名
       .mapValues(_.flatten.toSeq.sortWith(_._2 > _._2).take(100))
    
    // 触发Action,展示部分统计结果
    resultRDD.take(10).foreach(println)
  • 上述代码,已经完成功能实现。那么,这样的代码是否是最好的呢?答案是否定的。因为当前的排序是针对每个分块(Partition)的,一个Executor上有多个分块,每个分块有前100条数据需要shuffle,显然如果一个Executor一共只有100条数据需要shuffle才是最理想的!如果我们能有办法同时操纵每个Executor上的所有数据,获取前100条数据,那该多好啊!

  • 我们想要的排序流程示意图如下:
    Spark代码可读性与性能优化——示例六(groupBy、reduceByKey、aggregateByKey)_第1张图片

  • 然而,Spark并没提供一个类似mapPartition的可以对Executor上所有分块统一操作的算子(不然的话,我们就可以像mapPartion那样统计每Executor的前100名了)。不过我们有一个算子reduceByKey,它会在每个节点合并数据后再shuffle到一个节点进行最后的合并,这种行为似乎与我们需要的逻辑类似,不过好像又有那么一点不一样。

  • 你可能会说reduceByKey是合并,而我们的需求是排序啊!!!是的,这看上去似乎有点矛盾。

  • 事实上,这样是行得通的:

    1. 首先,让我们假想有这样一个集合类型A(内部是可排序的,并且只能拥有前100的数据,多余的会被删除)
    2. 接着,把每个元素(id,math)转换成含有一个元素的集合A
    3. 最后,使用reduceByKey,将每个集合依次相加合并!!!没错!就是合并!这样最后一个集合就是包含前100名的集合了。
  • 这样一个集合类型A,似乎在Scala、Java中不存在,不过有一个TreeSet能保证内部有序,我们可以在数据合并后手动提取前100,这样就可以了(另外,你也可以自己实现这样一个集合:3)

  • 第一步,先将id和math转为一个对象,并为这个对象实现equals、hashCode、compareTo方法,保证后续在TreeSet中的排序不会出问题。另外,再实现一个toString方法,方便我们查看打印效果!:)

    • Person.class 代码 (因为Java比较易懂、易写这几个方法,这里优先采用Java的形式,后面会附上Scala对应的实现类)
    public class Person implements Comparable<Person>, Serializable {
    
        private long id;
    
        private int math;
    
        public Person(long id, int math) {
            this.id = id;
            this.math = math;
        }
    
        @Override
        public int compareTo(Person person) {
            int result = person.math - this.math; // 降序
            if (result == 0) {
                result = person.id - this.id > 0 ? 1 : -1;
            }
            return result;
        }
    
        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
    
            Person person = (Person) o;
    
            return id == person.id;
        }
    
        @Override
        public int hashCode() {
            return (int) (id ^ (id >>> 32));
        }
    
        @Override
        public String toString() {
            return "Person{" +
                    "id='" + id + '\'' +
                    ", math=" + math +
                    '}';
        }
    
    }
    
    • TreeSet 使用示例
    import scala.collection.immutable.TreeSet
    
    object Demo {
    
      def main(args: Array[String]): Unit = {
        val set = TreeSet[Person](
          new Person(1231232L, 108),
          new Person(3214124L, 116),
          new Person(1321313L, 121),
          new Person(6435235L, 125)
        )
    
        // 获取前3名
        for (elem <- set.take(3)) {
          println(s"--> elem = $elem")
        }
      }
    
    }
    
  • 第二步,将原先的id、math封装为TreeSet

	studentDF.rdd
      .map(row => {
        val id = row.getLong(row.fieldIndex("id"))
        val math = row.getInt(row.fieldIndex("math"))
        val year = row.getInt(row.fieldIndex("year"))

        (year, TreeSet(new Person(id, math)))
      })
  • 最后,使用reduceByKey合并所有数据,得到前100名的结果
    val resultRDD = studentDF.rdd
      .map(row => {
        val id = row.getLong(row.fieldIndex("id"))
        val math = row.getInt(row.fieldIndex("math"))
        val year = row.getInt(row.fieldIndex("year"))

        (year, TreeSet(new Person(id, math)))
      })
      .reduceByKey((set1, set2) => set1 ++ set2 take 100)  // 依次合并2个Set,并只保留前100

    resultRDD.take(10).foreach(println)
  • Nice!!! 这样,我们就同时解决了排序问题和数据倾斜问题!
  • 进一步优化(aggregateByKey)
    • 细心的朋友应该已经发现了,reduceByKey之前的map为每条的数据都生成了一个TreeSet,这样会大大增加内存消耗。
    • 其实,我们只想要每个节点放一个可变的TreeSet(并且还能一直只存前100)。这样内存消耗就会更小!
    • 那么我们该如何做呢?设计一个MyTreeSet,采用aggregateByKey复用同一个Set,简略的示例如下:
    • MyTreeSet(简易实现,针对mutable.TreeSet封装)
      import scala.collection.mutable
      
      class MyTreeSet[A](firstNum: Int, elem: Seq[A])(implicit val ord: Ordering[A]) {
      
        val set: mutable.TreeSet[A] = mutable.TreeSet[A](elem: _*)
      
        def +=(elem: A): MyTreeSet[A] = {
          this add elem
      
          this
        }
      
        def add(elem: A): Unit = {
          set.add(elem)
      
          // 删除排在最后的多余元素
          check10Size()
        }
      
        def ++=(that: MyTreeSet[A]) : MyTreeSet[A] = {
          that.set.foreach(e => this add e)
      
          this
        }
      
        def check10Size(): Unit = {
          // 如果超过了firstNum个,就删除
          if (set.size > firstNum) {
            set -= set.last
          }
        }
      
        override def toString: String = set.toString
      }
      
      object MyTreeSet {
      
        def apply[A](elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](10, elem) // 默认保留前10
        
        def apply[A](firstNum: Int, elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](firstNum, elem)
        
      }
      
    • Spark部分代码
        val resultRDD = studentDF.rdd
            .map(row => {
              val id = row.getLong(row.fieldIndex("id"))
              val math = row.getInt(row.fieldIndex("math"))
              val year = row.getInt(row.fieldIndex("year"))
      
              (year, new Person(id, math))
            }).aggregateByKey(MyTreeSet[Person](100)) (
                (set, v) => set += v,
                (set1, set2) => set1 ++= set2
            )
      

2.4 最终代码,以及其他附件代码

  • 最终代码
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
    
    import scala.collection.immutable.TreeSet
    
    object GroupDemo3 {
    
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("GroupDemo")
        val spark = SparkSession.builder()
          .config(conf)
          .enableHiveSupport()
          .getOrCreate()
    
        val tbSchema = StructType(Array(
          StructField("id", LongType, true),
          StructField("chinese", IntegerType, true),
          StructField("math", IntegerType, true),
          StructField("english", IntegerType, true),
          StructField("year", IntegerType, true)
        ))
    
        // 获取原始数据
        val studentDF = spark.read.schema(tbSchema).table("tb_student_score")
          .where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL")
    
        // 开始进行分析
        val resultRDD = studentDF.rdd
          .map(row => {
            val id = row.getLong(row.fieldIndex("id"))
            val math = row.getInt(row.fieldIndex("math"))
            val year = row.getInt(row.fieldIndex("year"))
    
            (year, new Person(id, math))
          }).aggregateByKey(MyTreeSet[Person](100)) (
              (set, v) => set += v,
              (set1, set2) => set1 ++= set2
          ) // 依次合并2个Set,并只保留前100
    
        // 触发Action,展示部分统计结果
        resultRDD.take(10).foreach(println)
    
        spark.stop()
      }
    
    }
    
  • Person的Scala实现
    class PersonScala(val id: Long, val math: Int) extends Ordered[PersonScala]  with Serializable {
    
      override def compare(that: PersonScala): Int = {
        var result = that.math - this.math // 降序
        if (result == 0)
          result = if (that.id - this.id > 0) 1 else -1
        result
      }
    
      override def equals(obj: Any): Boolean = {
        obj match {
          case person: PersonScala => this.id == person.id
          case _ => false
        }
      }
    
      override def hashCode(): Int = (id ^ (id >>> 32)).toInt
    
      override def toString: String = "Person{" + "id=" + id + ", math=" + math + '}'
    
    }
    
    object PersonScala {
    
      def apply(id: Long, math: Int): PersonScala = new PersonScala(id, math)
    
    }
    
  • 示例——使用SQL获取历年数学的前100名(简单,但性能一般,且存在数据倾斜的可能)
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("GroupDemo")
        val spark = SparkSession.builder()
          .config(conf)
          .enableHiveSupport()
          .getOrCreate()
        
        // 使用sql分析
        val resultDF = spark.sql(
          """
            |SELECT year,id,math
            |FROM (
            | SELECT year,id,math,ROW_NUMBER() OVER (PARTITION BY year ORDER BY math DESC) rank 
            | FROM tb_student_score
            |) g
            |WHERE g.rank <= 100
          """.stripMargin)
        
        // 触发Action,展示部分统计结果
        resultDF.show()
    
        spark.stop()
      }
    

你可能感兴趣的:(Scala,BigData,#,Spark)