Copyright notice: this is the author's original post; do not repost without the author's permission.

Background: we need to group some DataFrame columns, collect them into a list per group, run a UDF over that list to produce a new list, and then explode the returned list back into rows. The input data:
| group_id | feature_1 | feature_2 | feature_3 |
|---|---|---|---|
| 1 | 1.11 | 1.21 | 1.31 |
| 1 | 1.12 | 1.22 | 1.32 |
| 2 | 2.11 | 2.21 | 2.31 |
| 2 | 2.12 | 2.22 | 2.32 |
Group by "group_id" and collect the feature columns into a list of dicts (e.g. for group 1: [{"feature_1": 1.11, "feature_2": 1.21, "feature_3": 1.31}, {"feature_1": 1.12, "feature_2": 1.22, "feature_3": 1.32}]). Then pass the collected list to a UDF that returns a new list, e.g. [{"feature_1": 11.1, "feature_2": 11.2, "feature_3": 11.3}, {"feature_1": 12.1, "feature_2": 12.2, "feature_3": 12.3}, {"feature_1": 13.1, "feature_2": 13.2, "feature_3": 13.3}]. Finally, explode the returned list back into rows, giving this result:
| group_id | feature_1 | feature_2 | feature_3 |
|---|---|---|---|
| 1 | 11.1 | 11.2 | 11.3 |
| 1 | 12.1 | 12.2 | 12.3 |
| 1 | 13.1 | 13.2 | 13.3 |
| 2 | 22.11 | 22.21 | 22.31 |
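Before looking at the Spark version, the overall flow (collect the rows of each group into a list, run a function over that list, then flatten the returned list back into rows) can be sketched in plain Python. `process_group` here is a hypothetical stand-in that returns the same dummy values as the UDF in the code below:

```python
from itertools import groupby
from operator import itemgetter

rows = [
    {"group_id": 1, "feature_1": 1.11, "feature_2": 1.21, "feature_3": 1.31},
    {"group_id": 1, "feature_1": 1.12, "feature_2": 1.22, "feature_3": 1.32},
    {"group_id": 2, "feature_1": 2.11, "feature_2": 2.21, "feature_3": 2.31},
    {"group_id": 2, "feature_1": 2.12, "feature_2": 2.22, "feature_3": 2.32},
]

def process_group(group_id, feature_list):
    # Hypothetical stand-in for the real per-group computation
    if group_id == 1:
        return [{"feature_1": 11.1, "feature_2": 11.2, "feature_3": 11.3},
                {"feature_1": 12.1, "feature_2": 12.2, "feature_3": 12.3},
                {"feature_1": 13.1, "feature_2": 13.2, "feature_3": 13.3}]
    return [{"feature_1": 22.11, "feature_2": 22.21, "feature_3": 22.31}]

result = []
key = itemgetter("group_id")
for group_id, grp in groupby(sorted(rows, key=key), key=key):
    # "collect_list": gather the feature dicts of this group into one list
    feature_list = [{k: r[k] for k in ("feature_1", "feature_2", "feature_3")} for r in grp]
    # "explode": emit one output row per dict returned for the group
    for d in process_group(group_id, feature_list):
        result.append({"group_id": group_id, **d})
```

Group 1 yields three output rows and group 2 yields one, matching the result table above.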
The code is as follows:
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import explode
from pyspark.sql.types import MapType, StringType, ArrayType, DoubleType

spark = SparkSession.builder.getOrCreate()

data_before = [
    (1, 1.11, 1.21, 1.31),
    (1, 1.12, 1.22, 1.32),
    (2, 2.11, 2.21, 2.31),
    (2, 2.12, 2.22, 2.32)
]
df = spark.createDataFrame(data_before, schema=['group_id', 'feature_1', 'feature_2', 'feature_3'])

def get_feature_dict(feature_1, feature_2, feature_3):
    # Collect the three feature columns into one dict
    feature_dict = dict()
    feature_dict['feature_1'] = feature_1
    feature_dict['feature_2'] = feature_2
    feature_dict['feature_3'] = feature_3
    return feature_dict

# Wrap as a UDF returning MapType(StringType(), DoubleType())
udf_feature_dict = F.udf(get_feature_dict, returnType=MapType(StringType(), DoubleType()))

def calc_feature(group_id, feature_list):
    '''
    Process the collected list of feature dicts for one group.
    :return: list of dicts
    '''
    if group_id == 1:
        return [{"feature_1": 11.1, "feature_2": 11.2, "feature_3": 11.3},
                {"feature_1": 12.1, "feature_2": 12.2, "feature_3": 12.3},
                {"feature_1": 13.1, "feature_2": 13.2, "feature_3": 13.3}]
    else:
        return [{"feature_1": 22.11, "feature_2": 22.21, "feature_3": 22.31}]

# Wrap as a UDF returning ArrayType(MapType(StringType(), DoubleType()))
udf_calc = F.udf(calc_feature, returnType=ArrayType(MapType(StringType(), DoubleType())))

df_calc = df.withColumn("feature_dict", udf_feature_dict(F.col("feature_1"), F.col("feature_2"), F.col("feature_3"))) \
    .groupBy('group_id') \
    .agg(F.collect_list("feature_dict").alias("feature_dict_list")) \
    .withColumn("calc_feature_list", udf_calc(F.col("group_id"), F.col("feature_dict_list"))) \
    .withColumn('calc_feature_detail', explode("calc_feature_list")) \
    .select('group_id',
            F.col("calc_feature_detail")["feature_1"].alias("feature_1"),
            F.col("calc_feature_detail")["feature_2"].alias("feature_2"),
            F.col("calc_feature_detail")["feature_3"].alias("feature_3")
            )
```
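Since `F.udf` only wraps an ordinary Python function, the two UDF bodies can be sanity-checked without a Spark session. They are re-stated here unchanged so the snippet stands alone:

```python
def get_feature_dict(feature_1, feature_2, feature_3):
    # Same body as the MapType UDF above
    return {'feature_1': feature_1, 'feature_2': feature_2, 'feature_3': feature_3}

def calc_feature(group_id, feature_list):
    # Same body as the ArrayType(MapType(...)) UDF above
    if group_id == 1:
        return [{"feature_1": 11.1, "feature_2": 11.2, "feature_3": 11.3},
                {"feature_1": 12.1, "feature_2": 12.2, "feature_3": 12.3},
                {"feature_1": 13.1, "feature_2": 13.2, "feature_3": 13.3}]
    return [{"feature_1": 22.11, "feature_2": 22.21, "feature_3": 22.31}]

# What collect_list produces for group 1: one dict per input row
d1 = get_feature_dict(1.11, 1.21, 1.31)
d2 = get_feature_dict(1.12, 1.22, 1.32)
# The per-group UDF returns three dicts for group 1 (three exploded rows)
# and one dict for group 2 (one exploded row)
out1 = calc_feature(1, [d1, d2])
out2 = calc_feature(2, [get_feature_dict(2.11, 2.21, 2.31), get_feature_dict(2.12, 2.22, 2.32)])
```

Checking the plain functions first is a quick way to debug UDF logic before paying the cost of a Spark job.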
-------------------------- Document info --------------------------
Author (BY): dkjkls (dkj卡洛斯)
Source: http://blog.csdn.net/dkjkls