Solving PySpark Data Skew with repartition & mapPartitions

In a project that computed 180-day view metrics per category, shop and SKU, user activity differs hugely across categories, shops and SKUs, so the job hit severe data skew during the computation.

For background on why data skew happens and how to tell whether a job is skewed, see my earlier post:

Spark处理数据倾斜问题 (Just Jump, CSDN blog)

 

To solve it I considered and tested several approaches, and in the end settled on repartition + mapPartitions.

Example usage of PySpark mapPartitions():

PySpark mapPartitions() Examples - Spark By {Examples}

# Code example from the tutorial linked above
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
data = [('James','Smith','M',3000),
  ('Anna','Rose','F',4100),
  ('Robert','Williams','M',6200), 
]

columns = ["firstname","lastname","gender","salary"]
df = spark.createDataFrame(data=data, schema = columns)
df.show()

#Example 1 mapPartitions()
def reformat(partitionData):
    for row in partitionData:
        yield [row.firstname+","+row.lastname,row.salary*10/100]
df2=df.rdd.mapPartitions(reformat).toDF(["name","bonus"])
df2.show()

#Example 2 mapPartitions()
def reformat2(partitionData):
  updatedData = []
  for row in partitionData:
    name=row.firstname+","+row.lastname
    bonus=row.salary*10/100
    updatedData.append([name,bonus])
  return iter(updatedData)

df2=df.rdd.mapPartitions(reformat2).toDF(["name","bonus"])
df2.show()

My own case:

The task: count distinct users who viewed each level of mall category, over quarterly and half-year windows. Some categories are hot, with many visitors and high visit frequency; others serve relatively low-frequency demand and get far fewer visitors. Grouping directly by category and counting distinct users therefore produces exactly the kind of skew described above. The fix combines repartition + mapPartitions (+ UDF):

(1) repartition() by the column that tends to cause the skew together with the grouping columns used in the next aggregation step. Note that the two parameters below, parallelism and the number of shuffle partitions, must be raised at the same time, otherwise the larger partition count cannot take effect (flags for spark-submit; a sketch of the in-code equivalent follows them):

--conf spark.sql.shuffle.partitions=10000 \
--conf spark.default.parallelism=10000 \
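
The same settings can also be applied in code. A minimal sketch (the app name is a placeholder): note that spark.default.parallelism only takes effect if it is set before the SparkContext is created, while spark.sql.shuffle.partitions is an ordinary runtime SQL conf and can also be changed later with spark.conf.set().

# A sketch of setting the same values in code, equivalent to the spark-submit
# flags above. spark.default.parallelism must be in place before the
# SparkContext starts; spark.sql.shuffle.partitions may also be adjusted at
# runtime with spark.conf.set("spark.sql.shuffle.partitions", "10000").
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("view_users_analysis")  # placeholder app name
         .config("spark.sql.shuffle.partitions", "10000")
         .config("spark.default.parallelism", "10000")
         .getOrCreate())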

(2) Use mapPartitions() to do a first round of aggregation inside each partition, then run the final aggregation over those intermediate results; a stripped-down sketch of the pattern is below, followed by the full code from my project.
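
A minimal sketch of the pattern first, with hypothetical column names user_id and cate_cd standing in for the real project schema (df is assumed to hold one row per view event):

# Sketch of the two-stage distinct count (hypothetical columns user_id / cate_cd):
# 1) repartition by (user_id, cate_cd): every copy of a pair lands in one partition;
# 2) mapPartitions dedups the pairs and pre-counts distinct users per category;
# 3) an ordinary groupBy + sum merges the per-partition partial counts; the final
#    shuffle only moves one small row per (partition, category), so it cannot skew.
from pyspark.sql.functions import sum as _sum

def partial_distinct_count(rows):
    seen = set()
    for r in rows:
        seen.add((r.user_id, r.cate_cd))
    counts = {}
    for _, cate in seen:
        counts[cate] = counts.get(cate, 0) + 1
    for cate, cnt in counts.items():
        yield (cate, cnt)

result = (df.repartition(10000, "user_id", "cate_cd")
            .rdd.mapPartitions(partial_distinct_count)
            .toDF(["cate_cd", "partial_cnt"])
            .groupBy("cate_cd")
            .agg(_sum("partial_cnt").alias("user_cnt")))

The full project code: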

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time: 2022/12/22 4:41 PM
# @Author: TobyGao

import numpy as np
from utils import getNdays, getNMonth, filterUdf, getTaskSavePath,\
    keep_window, get_output_table
from pyspark.sql.functions import lit, col, expr, sum, count
from pyspark.sql.types import *
import math


class ViewUsersAnalysis:
    def __init__(self, spark, task_name, calc_date):
        self.spark = spark
        self.task_name = task_name
        self.calc_date = calc_date

        self.keep_window = keep_window
        self.__output_table = get_output_table(task_name)
        self.__view_data_schema = StructType([
            StructField("user_id", StringType(), False),
            StructField("first_cate_cd", StringType(), False),
            StructField("first_cate_name", StringType(), False),
            StructField("second_cate_cd", StringType(), False),
            StructField("second_cate_name", StringType(), False),
            StructField("third_cate_cd", StringType(), False),
            StructField("third_cate_name", StringType(), False),
            StructField("shop_id", StringType(), False),
            StructField("shop_name", StringType(), False),
            StructField("main_sku_id", StringType(), False),
            StructField("sku_name", StringType(), False),
            StructField("request_date", StringType(), False)])

        self.spark.udf.register("filterUdf", filterUdf)
    
    def __load_view_data(self, window_type):
        start_date = getNdays(self.calc_date, -window_type)
        action_info = self.spark.createDataFrame(
            self.spark.sparkContext.emptyRDD(), self.__view_data_schema)

        data_path_list = []
        for n in range(math.ceil(window_type / 30) + 1):
            data_path_list.append(
                getTaskSavePath(self.task_name, getNMonth(self.calc_date, -n) + "-*"))

        for file_path in data_path_list:
            print("input file paths:", file_path)
            action_info = action_info.unionAll(self.spark.read.format("parquet").load(file_path))

        return action_info


    @staticmethod
    def __partition_agg_func_cate1(row_list):
        # Per-partition pre-aggregation: because the data is repartitioned by
        # (user_id, first_cate_cd), every copy of a pair lands in one partition,
        # so deduping and counting here gives partial counts that can simply be
        # summed by the later groupBy.
        cate_code_dict = dict()  # first_cate_cd -> first_cate_name
        res_cate_dict = dict()   # distinct (user_id, first_cate_cd) pairs in this partition
        res_dict = dict()        # first_cate_cd -> distinct user count
        result = []
        for row in row_list:
            user_id = row.user_id
            first_cate_cd = row.first_cate_cd
            first_cate_name = row.first_cate_name

            if first_cate_cd not in cate_code_dict.keys():
                cate_code_dict[first_cate_cd] = first_cate_name

            cate_key = (user_id, first_cate_cd)
            if cate_key not in res_cate_dict.keys():
                res_cate_dict[cate_key] = 1

        for k in res_cate_dict.keys():
            agg_key = k[1]
            if agg_key in res_dict.keys():
                res_dict[agg_key] += 1
            else:
                res_dict[agg_key] = 1
        res_cate_dict.clear()

        for k in res_dict.keys():
            result.append([k, cate_code_dict[k], res_dict[k]])
        res_dict.clear()
        cate_code_dict.clear()
        return iter(result)

    
    def get_count_results_in_long_window(self, action_info, window_type):
        action_info.cache()
        # Distinct viewers per first-level category: repartition by the skew keys
        # (user_id, first_cate_cd), pre-count inside each partition with
        # mapPartitions, then sum the per-partition partial counts.
        view_cate1 = action_info.repartition(10000, col("user_id"), col("first_cate_cd")) \
            .rdd.mapPartitions(self.__partition_agg_func_cate1) \
            .toDF(["first_cate_cd", "first_cate_name", "user_cnt_first_cate_view"]) \
            .groupBy("first_cate_cd", "first_cate_name") \
            .agg(sum("user_cnt_first_cate_view").alias("user_cnt_first_cate_view"))

        view_cate1.cache()
        view_cate1.show(20)
        

        # Shop-level view counts, built with the same repartition-first idea.
        view_shop = action_info.repartition(10000, col("user_id"),
                                                col("shop_id")) \
            .dropDuplicates().groupBy("first_cate_cd",
                                      "first_cate_name",
                                      "second_cate_cd",
                                      "second_cate_name",
                                      "third_cate_cd",
                                      "third_cate_name",
                                      "shop_id",
                                      "shop_name") \
            .agg(count("user_id").alias("user_cnt_shop_view")) 

        view_shop.show()
            

    def get_count_results_in_windows(self):
        for window_type in [60, 90]:
            # self.spark.sql(
            #     """alter table {output_table} drop partition (dt='{calc_date}', window_type='{window_type}')""".format(
            #         output_table=self.__output_table, calc_date=self.calc_date, window_type=window_type))
            action_info = self.__load_view_data(window_type).drop("request_date").withColumn(
                "window_type", lit(window_type))
            self.get_count_results_in_long_window(action_info, window_type)
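
For completeness, a sketch of how I drive the class above; the app name, task name and date are placeholders, and the shuffle/parallelism settings are passed via the spark-submit flags shown earlier:

# Driver sketch (placeholder parameters only).
from pyspark.sql import SparkSession

if __name__ == "__main__":
    spark = SparkSession.builder \
        .appName("view_users_analysis") \
        .getOrCreate()
    job = ViewUsersAnalysis(spark, task_name="view_users", calc_date="2022-12-22")
    job.get_count_results_in_windows()
    spark.stop()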

The limitation of this approach is its memory cost: ideally the number of keys each partition has to hold temporarily should not be too large. But it does solve the data-skew problem very well.
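
Because the per-partition memory is driven by how many distinct keys land in each partition, it is worth sanity-checking the partition sizes after the repartition. A small diagnostic sketch of my own, where repartitioned_df is a placeholder for any DataFrame already repartitioned by the skew columns:

# Diagnostic sketch: count the rows in every partition of an already-repartitioned
# DataFrame (repartitioned_df is a placeholder) and print the spread, to confirm
# the hot keys are now evenly distributed.
def partition_size(rows):
    n = 0
    for _ in rows:
        n += 1
    yield n

sizes = repartitioned_df.rdd.mapPartitions(partition_size).collect()
print("partitions:", len(sizes), "min rows:", min(sizes), "max rows:", max(sizes))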
