pyspark-03 UDF和Pandas_UDF

目录

  • udf
  • pandas_udf


udf

pyspark里面常常需要自定义函数进行数据处理

  1. udf是针对一行数据进行处理
  2. pandas_udf是针对一个series进行处理
  3. udfa是针对groupby之后的数据进行处理

应用类似于pandas里面的apply函数,

import numpy as np
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import StringType, FloatType

pandas_df = pd.DataFrame({'row_number': np.arange(100000)})
spark_df = spark.createDataFrame(pandas_df) 
spark_df.show(5)
+----------+
|row_number|
+----------+
|         0|
|         1|
|         2|
|         3|
|         4|
+----------+

# 定义udf函数,开根号
@F.udf(returnType=FloatType())
def cacl_sqrt(x):
    return float(np.sqrt(x))
 spark_df.withColumn('a2', cacl_sqrt('row_number')).show()
 
 +----------+---------+
|row_number|       a2|
+----------+---------+
|         0|      0.0|
|         1|      1.0|
|         2|1.4142135|
|         3|1.7320508|
|         4|      2.0|

  • udf传入的参数x,就是row_number列里每一行数据
  • udf默认返回是string,returnType建议指定好返回类型
  • udf不支持numpy格式数据,需要转为基本类型float(np.sqrt(x))

聚合后结合collect_list使用udf,实现udaf功能

rdd = spark.sparkContext.parallelize([
    (1, 2., 'sdsd|sdsd:sdsd', date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
    (1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (2, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (2, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (2, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (2, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (3, 4., '', date(2000, 3, 1), datetime(2000, 1, 3, 12, 0))
])
df = spark.createDataFrame(rdd, 
                           schema=['a', 'b', 'c', 'd', 'e'])
df.show()
+---+---+--------------+----------+-------------------+
|  a|  b|             c|         d|                  e|
+---+---+--------------+----------+-------------------+
|  1|2.0|sdsd|sdsd:sdsd|2000-01-01|2000-01-01 12:00:00|
|  1|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  1|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  1|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  1|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  1|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  1|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  1|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  1|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  1|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  2|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  2|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  2|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  2|3.0|   20202_19001|2000-02-01|2000-01-02 12:00:00|
|  3|4.0|              |2000-03-01|2000-01-03 12:00:00|
+---+---+--------------+----------+-------------------+

# 根据a聚合,求b的均值
# 1. 聚合结果
df.groupby('a').agg(F.collect_list('b').alias('new_b')).show()
+---+--------------------+
|  a|               new_b|
+---+--------------------+
|  1|[2.0, 3.0, 3.0, 3...|
|  2|[3.0, 3.0, 3.0, 3.0]|
|  3|               [4.0]|
+---+--------------------+

# 2. 求均值
@F.udf(returnType=FloatType())
def calc(x_list):
    return float(np.mean(x_list))
df.groupby('a').agg(calc(F.collect_list('b')).alias('new_b')).show()
+---+-----+
|  a|new_b|
+---+-----+
|  1|  2.9|
|  2|  3.0|
|  3|  4.0|
+---+-----+
  • 聚合后的collect_list,和hive里面一样的效果,转为list。因此传入udf的参数,就是一个list

pandas_udf

主要是针对一个series处理,因此在处理每一行数据的时候,可以得到一整个列的统计信息

@F.pandas_udf(FloatType())
def cacl_sqrt2(x):
    return np.sqrt(x)

spark_df.withColumn('a4', cacl_sqrt2('row_number')).show()
+----------+---------+
|row_number|       a3|
+----------+---------+
|         0|      0.0|
|         1|      1.0|
|         2|1.4142135|
|         3|1.7320508|
|         4|      2.0|


# 计算row_number每个元素和row_number列平均值的差
@F.pandas_udf(FloatType())
def cacl_sqrt2(x):
    return x - np.mean(x)
spark_df.withColumn('a1', cacl_sqrt2('row_number')).show(5)
+----------+-------+
|row_number|     a1|
+----------+-------+
|         0|-4999.5|
|         1|-4998.5|
|         2|-4997.5|
|         3|-4996.5|
|         4|-4995.5|
+----------+-------+
  • pandas_udf里面的参数x,就是row_number一整列,以pandas.series形式传入,整体处理
  • 返回类型是指series里面元素的类型
  • “计算row_number每个元素和row_number列平均值的差”, 如果用udf,得先计算均值,然后作为参数传入udf计算差值

求row_number列的平方

@F.pandas_udf(FloatType())
def cacl_sqrt3(a):
    return pd.Series([i**2 for i in range(len(a))])
    
spark_df.withColumn('a4', cacl_sqrt3('row_number')).show()
+----------+-----+
|row_number|   a4|
+----------+-----+
|         0|  0.0|
|         1|  1.0|
|         2|  4.0|
|         3|  9.0|
|         4| 16.0|
  • 这里是展示如果自定义处理,返回的数据必须以series的格式返回

pandas_udf比udf有较大优势的还有在模型预测上,比如训练一个xgboost模型,没必要用pyspark来写,流程复杂,变量需要指定各种Index,Vector等。可以用下面的流程:

  1. pyspark做数据处理和特征工程,参考 画像分析
  2. 单机训练模型, 基本的xgboost训练, 保存为.pkl文件, 参考xgboost官网
  3. 上传模型到集群pyspark预估,参考如下代码:
# 假设训练一个模型预测某个用户的性别

@F.pandas_udf(returnType=FloatType())
def pred_gender_proba(*feature_list):
    x_data = pd.concat(feature_list, axis=1).values
    y_pred = model.predict(x_data, num_iteration=model.best_iteration)
    return pd.Series(y_pred)

features = ['feat1', 'feat2', 'feat3']
df = df.withColumn('gender_prob', pred_gender_proba(*features))
  • 利用pandas_udf预估比udf效率会高很多

你可能感兴趣的:(pyspark记录,pandas,python,数据分析)