spark2.4 sql 快速列去重(冗余列检查)

        一直想做一个勤奋的人,笔耕方田,将自己在从事spark开发四年来积累的奇淫巧技分享出来。在给大家提供参考方案的同时也在总结和优化之前的设计。如果在有幸碰到大牛忍不住提出更好的优化方案能从中受益,也不枉码了这么多字。每当设计出一个很好的计算方案,就会忍不住打开博客想分享出来。然后会一直琢磨该以什么样的文字描述出来,想着想着就放弃了。总是想构思一个比较完美的结构段落,结果到最后什么也没有写出来。分享也就这样一直拖下去了。可能也有很多勤奋的人,和我一样毁于强迫症。

         回归正题

        背景:最近在做的项目,需要处理大量的列。需要对这些列进行去除冗余列,保留不相同的列。大概有15000+列左右。

        第一版设计

        使用了最原始的计算方式,两两组合进行比较。算倒是能够计算的出来,只是性能实在试太差了,并且时不时的会爆出codeGene buffer 超出了64k。调整了很久最终也能稳定的计算过去,只是这样计算实在试太慢。

        性能分析

       1.5万列,两两组合大概有15000*14999/2=112492500种组合,由于会边比较边丢失冗余列所以组合数大概在0.9亿左右。先不考虑行数,列组合比较就需要0.9亿组合,甚至列组合比行数还要多。这种计算实在是太消耗资源和时间!

        第二版设计

        先计算每列的非空count和基元个数(count distinct),只有count和基元个数相同的情况下,再去两两组合比较。虽然在前期计算非空count和(count distinct)消耗了不少时间,但是整体计算时间变成了原来的三分之一。

        性能分析

        如何计算1.5w列,每列的非空count和基元个数,也经历了两个版本的调优。后期会写博客分享如何只用三次shuffle计算1.5万列的count和基元数

       第三版设计

       从做第一版开始,就在思考一个问题:如何将hash或者hashset应用到这个场景中。使用hashSet或者HashMap可定会有一种更快的设计方案,可以快速算出冗余列。后来慢慢就设计出了一个方案,使用Set自动去重的特性设计出了这一版方案

       原理:使用UDAF,众所周知,在spark中编写UDAF需要实现比较重要的几个函数:

                 update :相当于AggregateByKey的map端的操作,将每条数据做缘生意转换并放到初始容器中

                 merge:相当于AggregateBykey的reduce端操作,做容器容器之间的合并

                 evaluate:最后转换容器中数据,返回最最终数据

                 inputSchema:输入列的数据类型

                 bufferSchema:中间容器的数据类型

                 dataType:最终返回数据的数据类型

                 initialize: 初始化容器

起始原理非常简单,首先使用spark sql内置函数,把需要比较的列放到array中array(colArray:_*).alia("tmp_arr"),然后调用编写的UDAF。

   UDAF update函数:

      伪代码:array.zipwithPartition.toMap.keySet.toArray

      描述:将传入的值和下标配接在一起,然后把值相同的去掉,最后只留当前row,全部不相同的列和列下标。由于spark sql不支持set数据结构,所以最终需要将数据转回array。由于在scala中toMap是将array中数据依次添加到Map中,所有后面的数据如果和前面的相同,后面的值和下标会覆盖已经出现过的。最终返回全部不一样的列的下标。

    UDAF merge函数:

     在merge函数中只需要将传过来的下标数组取交集就可以了

     这样最终返回的就是全部不相同的列下标数组

     计算过程演进:

         spark2.4 sql 快速列去重(冗余列检查)_第1张图片

如下为代码:

class ChckUDAF extends UserDefinedAggregateFunction {
  val logger = LoggerFactory.getLogger(getClass)

  override def inputSchema: StructType = StructType(Array(
    StructField("keys", DataTypes.createArrayType(StringType))
  ))

  override def bufferSchema: StructType = StructType(Array(StructField("buff", DataTypes.createMapType(IntegerType, StringType))))

  override def dataType: DataType = DataTypes.createArrayType(IntegerType)

  override def deterministic: Boolean = false

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Map[String, String]())

  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val indexs: Map[Int, String] = input.getAs[Seq[String]](0)
      .zipWithIndex
      .filter(_._1 != null)
      .toMap.map(tup => (tup._2, null: String))
    if (indexs.nonEmpty) {
      val oldMap = buffer.getAs[Map[Int, String]](0)
      val result = oldMap ++ indexs
      buffer.update(0, result)
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map2 = buffer2.getMap[Int, String](0)
    if (map2 != null && map2.nonEmpty) {
      val map1 = buffer1.getMap[Int, String](0)
      buffer1.update(0, map1 ++ map2)
    }
  }

  override def evaluate(buffer: Row): Any = buffer.getMap[Int, String](0).keySet.toArray
}

 

你可能感兴趣的:(大数据)