Using Spark SQL to Traverse a Tree-Structured Table with Parent-Child IDs and Build a Leveled Dimension Table

Hive does not support recursive CTEs, but the same leveled result can be produced by traversing the tree level by level with Spark SQL.
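
For contrast, on an engine that does support recursion, such as PostgreSQL, the whole traversal collapses into a single WITH RECURSIVE query. The sketch below mirrors the org example used later in this post; the table and column names are assumptions:

with recursive org_tree as (
    -- anchor member: the root nodes
    select id, name, parent_id,
           1 as org_level,
           cast(id as text) as org_id_path
    from org_source
    where parent_id = 0
    union all
    -- recursive member: attach the children of the previous level
    select c.id, c.name, c.parent_id,
           p.org_level + 1,
           p.org_id_path || ',' || cast(c.id as text)
    from org_source c
    join org_tree p on c.parent_id = p.id
)
select * from org_tree;

Neither Hive nor Spark SQL 2.4 accepts this syntax, which is why the script below simulates the recursion with a driver-side loop.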

Overall approach:

  1. Prepare the source data, keeping at minimum each node's id and the id of its parent node.
  2. Fetch the root nodes as level 1, save that level's rows, and tag them with their level number.
  3. Join the latest level's node ids against the parent ids in the source data; the matched rows become the next level. Save and tag that level as well.
  4. Repeat step 3 until the join returns no rows (the matched row count is 0); a small worked example follows this list.
  5. Combine all the saved levels with UNION ALL and insert the result into the target table. When inserting, check whether extra rows need to be added, such as special placeholder values and nodes the traversal could not reach.
  6. To avoid an infinite loop on cyclic data, cap the traversal depth and abort the script with an error when the cap is exceeded.
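
As a quick illustration, take a hypothetical four-row source table (all ids and names invented for this example):

    org_id  org_name  parent_id
    1       HQ        0
    2       Sales     1
    3       R&D       1
    4       Sales-CN  2

The traversal would then run like this:

    Pass 1 (roots, parent_id = 0): node 1 becomes level 1, name path "HQ"
    Pass 2 (children of level 1):  nodes 2 and 3 become level 2, paths "HQ,Sales" and "HQ,R&D"
    Pass 3 (children of level 2):  node 4 becomes level 3, path "HQ,Sales,Sales-CN"
    Pass 4: no rows match, so the traversal stops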

The full script below uses organization (org) hierarchy data as the example:
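
The script reads a daily-partitioned Hive source table. Its assumed shape, inferred from the queries below (the real DDL and column types may differ):

create table ods.ods_org_dt (
    id          bigint,
    name        string,
    description string,
    parent_id   bigint   -- 0 marks a root node
)
partitioned by (statdate string);  -- partition key, yyyyMMdd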

# -*- coding:UTF-8 -*-
# desc: clean tree-structured node data of unknown depth into leveled dimension rows
# spark version: 2.4.1
# python version: 2.7.5

import sys

from pyspark.sql import SparkSession

# Python 2 only: force UTF-8 as the default encoding so the Chinese string
# literals in the SQL below round-trip safely
reload(sys)
sys.setdefaultencoding("utf-8")


def main():
    # First command-line argument; in this environment it is a yyyyMMdd date
    data_date = sys.argv[1]

    # Parameters substituted into every SQL statement
    param = {
        "statdate": data_date
    }

    # Set spark.sql.session.timeZone=UTC so that time conversions keep
    # behaving the same when HiveQL is migrated to Spark SQL
    spark_session_init = SparkSession.builder \
        .config("spark.shuffle.service.enabled", "true") \
        .config("spark.dynamicAllocation.enabled", "true") \
        .config("spark.dynamicAllocation.schedulerBacklogTimeout", "20s") \
        .config("spark.dynamicAllocation.initialExecutors", "1") \
        .config("spark.dynamicAllocation.minExecutors", "1") \
        .config("spark.dynamicAllocation.maxExecutors", "2") \
        .config("spark.dynamicAllocation.executorIdleTimeout", "120s") \
        .config("spark.executor.memory", "4g") \
        .config("spark.executor.memoryOverhead", "4096") \
        .config("spark.executor.cores", "2") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("spark.locality.wait", "500ms") \
        .config("spark.default.parallelism", 20) \
        .config("spark.sql.shuffle.partitions", 20) \
        .config("spark.sql.session.timeZone", "UTC") \
        .enableHiveSupport() \
        .getOrCreate()

    with spark_session_init as spark_session:
        def hive_sql(sql_string):
            # Substitute the shared parameters (statdate) and run the statement
            sql_string_with_param = sql_string.format(**param)
            return spark_session.sql(sql_string_with_param)

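        # Step 1: load the partition's source nodes (id + parent id) and cache
        # the view, since every traversal pass joins against it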
        hive_sql(
            """
            select
                 id as org_id
                ,trim(name) as org_name
                ,trim(description) as org_desc
                ,parent_id
            from ods.ods_org_dt
            where statdate='{statdate}'
            """
        ).cache().createOrReplaceTempView("org_source")

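        # Step 2: level 1 = the root nodes (parent_id = 0); seed the level
        # counter and the id/name path columns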
        level_count = 1
        org_df = hive_sql(
            """
            select
                 org_id
                ,org_name
                ,org_desc
                ,0 as parent_id
                ,{now_level_count} as org_level
                ,cast(org_id as string) as org_id_path
                ,cast(org_name as string) as org_name_path
            from org_source
            where parent_id = 0
            """.format(**{"now_level_count": level_count, "last_level_count": (level_count - 1)})
        ).cache()

        level_item_count = org_df.count()
        org_level_table_name = "org_level_{now_level_count}".format(**{"now_level_count": level_count})
        org_df.createOrReplaceTempView(org_level_table_name)
        # Accumulate a UNION ALL across the per-level temp views; the final
        # INSERT executes this assembled statement
        union_all_sql = "select org_id,org_name,org_desc,parent_id,org_level,org_id_path,org_name_path from " + org_level_table_name

        if level_item_count == 0:
            raise Exception("First level is empty.")

        # Traverse level by level; the loop exits when a pass matches no rows
        while True:
            level_count = level_count + 1

            # Cap the traversal depth to guard against cycles in the data
            if level_count > 10:
                raise Exception("Level count > 10, please check if circle appears.")

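            # Step 3: join the previous level's node ids to the source rows'
            # parent_id; the matches form the next level, with paths extended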
            org_df = hive_sql(
                """
                select
                     m1.org_id
                    ,m1.org_name
                    ,m1.org_desc
                    ,m1.parent_id
                    ,{now_level_count} as org_level
                    ,concat(m2.org_id_path,',',cast(m1.org_id as string)) as org_id_path
                    ,concat(m2.org_name_path,',',cast(m1.org_name as string)) as org_name_path
                from org_source m1
                join org_level_{last_level_count} m2
                  on m1.parent_id = m2.org_id
                """.format(**{"now_level_count": level_count, "last_level_count": (level_count - 1)})
            ).cache()

            level_item_count = org_df.count()
            org_level_table_name = "org_level_{now_level_count}".format(**{"now_level_count": level_count})
            org_df.createOrReplaceTempView(org_level_table_name)

            # When the join returns no rows, there are no deeper levels left;
            # stop traversing
            if level_item_count == 0:
                break

            union_all_sql = union_all_sql + " union all select org_id,org_name,org_desc,parent_id,org_level,org_id_path,org_name_path from " + org_level_table_name

        # Insert the assembled hierarchy into the target table; note the extra
        # rows for the 【未知组织】 (unknown org) placeholder and for orgs the
        # traversal could not reach
        hive_sql(
            """
            with step1 as
            (
                %s
            )
            ,step2 as
            (
            select
                 org_id as org_key
                ,org_id
                ,org_name
                ,org_desc
                ,parent_id
                ,org_level
                ,org_id_path
                ,org_name_path
            from step1
            union all
            select
                 org_id as org_key
                ,org_id
                ,org_name
                ,org_desc
                ,parent_id
                ,-9999 as org_level
                ,'' as org_id_path
                ,'' as org_name_path
            from org_source
            where not exists(select 1 from step1 where step1.org_id = org_source.org_id)
            union all
            select
                 -9999 as org_key
                ,null as org_id
                ,'【未知组织】' as org_name
                ,'对应不到真实组织id的数据,都将归为【未知组织】,统一维度展示时的称呼' as org_desc
                ,-9999 as parent_id
                ,-9999 as org_level
                ,'' as org_id_path
                ,'' as org_name_path
            )
            insert overwrite table test.dim_org_dt partition(statdate='{statdate}')
            select /*+ repartition(1) */
                 org_key
                ,org_id
                ,org_name
                ,coalesce(org_desc,'') as org_desc
                ,org_level
                ,coalesce(split(org_name_path,',')[0],'') as org_level_1_name
                ,coalesce(split(org_name_path,',')[1],'') as org_level_2_name
                ,coalesce(split(org_name_path,',')[2],'') as org_level_3_name
                ,coalesce(split(org_name_path,',')[3],'') as org_level_4_name
                ,coalesce(split(org_name_path,',')[4],'') as org_level_5_name
                ,coalesce(cast(split(org_id_path,',')[0] as bigint),-9999) as org_level_1_id
                ,coalesce(cast(split(org_id_path,',')[1] as bigint),-9999) as org_level_2_id
                ,coalesce(cast(split(org_id_path,',')[2] as bigint),-9999) as org_level_3_id
                ,coalesce(cast(split(org_id_path,',')[3] as bigint),-9999) as org_level_4_id
                ,coalesce(cast(split(org_id_path,',')[4] as bigint),-9999) as org_level_5_id
                ,parent_id as org_parent_id
                ,org_name_path
                ,org_id_path
            from step2
            """ % union_all_sql
        )


if __name__ == "__main__":
    main()
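
The script takes the partition date as its only argument; a typical launch (the file name here is assumed) looks like:

spark-submit --master yarn --deploy-mode cluster org_tree_dim.py 20240101

Since the session enables dynamic allocation with spark.shuffle.service.enabled=true, the YARN NodeManagers must run Spark's external shuffle service. One design note: accumulating the UNION ALL as a SQL string keeps every level registered as its own temp view (org_level_1, org_level_2, ...), which makes it easy to inspect a single level when debugging; the per-level DataFrames could equally be combined with DataFrame union calls instead.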
