In PySpark you often need a user-defined function (UDF) for custom data processing, much like the apply function in pandas. First set up a small demo DataFrame:
import numpy as np
import pandas as pd
from datetime import date, datetime  # used in the aggregation example below

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import StringType, FloatType

spark = SparkSession.builder.getOrCreate()

pandas_df = pd.DataFrame({'row_number': np.arange(100000)})
spark_df = spark.createDataFrame(pandas_df)
spark_df.show(5)
+----------+
|row_number|
+----------+
| 0|
| 1|
| 2|
| 3|
| 4|
+----------+
# Define a UDF that takes the square root of a value
@F.udf(returnType=FloatType())
def calc_sqrt(x):
    return float(np.sqrt(x))

spark_df.withColumn('a2', calc_sqrt('row_number')).show(5)
+----------+---------+
|row_number|       a2|
+----------+---------+
|         0|      0.0|
|         1|      1.0|
|         2|1.4142135|
|         3|1.7320508|
|         4|      2.0|
+----------+---------+
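The decorator is only one way to create the UDF; the same function can be wrapped with F.udf directly, or registered so it can be used from SQL expressions. A small sketch (the view name sqrt_demo is just for illustration):
# Equivalent without the decorator
calc_sqrt_udf = F.udf(lambda x: float(np.sqrt(x)), FloatType())
spark_df.withColumn('a2', calc_sqrt_udf('row_number')).show(5)

# Register it for use in Spark SQL
spark.udf.register('calc_sqrt_sql', lambda x: float(np.sqrt(x)), FloatType())
spark_df.createOrReplaceTempView('sqrt_demo')
spark.sql('SELECT `row_number`, calc_sqrt_sql(`row_number`) AS a2 FROM sqrt_demo').show(5)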
A udf can also be applied after a groupby, combined with collect_list, to get UDAF-like behavior. First build a demo DataFrame:
rdd = spark.sparkContext.parallelize([
(1, 2., 'sdsd|sdsd:sdsd', date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
(1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(1, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(2, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(2, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(2, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(2, 3., '20202_19001', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
(3, 4., '', date(2000, 3, 1), datetime(2000, 1, 3, 12, 0))
])
df = spark.createDataFrame(rdd, schema=['a', 'b', 'c', 'd', 'e'])
df.show()
+---+---+--------------+----------+-------------------+
| a| b| c| d| e|
+---+---+--------------+----------+-------------------+
| 1|2.0|sdsd|sdsd:sdsd|2000-01-01|2000-01-01 12:00:00|
| 1|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 1|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 1|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 1|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 1|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 1|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 1|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 1|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 1|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 2|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 2|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 2|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 2|3.0| 20202_19001|2000-02-01|2000-01-02 12:00:00|
| 3|4.0| |2000-03-01|2000-01-03 12:00:00|
+---+---+--------------+----------+-------------------+
# Group by a and compute the mean of b
# 1. Collect the values of b per group
df.groupby('a').agg(F.collect_list('b').alias('new_b')).show()
+---+--------------------+
| a| new_b|
+---+--------------------+
| 1|[2.0, 3.0, 3.0, 3...|
| 2|[3.0, 3.0, 3.0, 3.0]|
| 3| [4.0]|
+---+--------------------+
# 2. Apply a udf to the collected list to compute the mean
@F.udf(returnType=FloatType())
def calc(x_list):
    return float(np.mean(x_list))
df.groupby('a').agg(calc(F.collect_list('b')).alias('new_b')).show()
+---+-----+
| a|new_b|
+---+-----+
| 1| 2.9|
| 2| 3.0|
| 3| 4.0|
+---+-----+
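As an alternative to the collect_list + udf trick, a grouped-aggregate pandas_udf gives you a true UDAF. A sketch, assuming Spark 2.4+ with PyArrow installed (in Spark 3 the same thing is usually written with Python type hints instead of PandasUDFType):
# The function receives all b values of one group as a pandas Series
# and returns a single scalar per group
@F.pandas_udf(FloatType(), F.PandasUDFType.GROUPED_AGG)
def mean_udaf(x):
    return float(x.mean())

df.groupby('a').agg(mean_udaf('b').alias('new_b')).show()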
A pandas_udf, by contrast, operates on a whole pandas Series at a time rather than on single values, so while processing the rows you can use vectorized NumPy/pandas operations and statistics computed over the batch of values the function receives (see the note after the mean-difference example below).
# The same square root, now as a vectorized pandas_udf
@F.pandas_udf(FloatType())
def calc_sqrt2(x):
    return np.sqrt(x)

spark_df.withColumn('a3', calc_sqrt2('row_number')).show(5)
+----------+---------+
|row_number|       a3|
+----------+---------+
|         0|      0.0|
|         1|      1.0|
|         2|1.4142135|
|         3|1.7320508|
|         4|      2.0|
+----------+---------+
# Difference between each element of row_number and the mean of the
# batch of values the udf receives (see the note after the output)
@F.pandas_udf(FloatType())
def calc_mean_diff(x):
    return x - np.mean(x)

spark_df.withColumn('a1', calc_mean_diff('row_number')).show(5)
+----------+-------+
|row_number| a1|
+----------+-------+
| 0|-4999.5|
| 1|-4998.5|
| 2|-4997.5|
| 3|-4996.5|
| 4|-4995.5|
+----------+-------+
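Note that the result is -4999.5 rather than -49999.5: a scalar pandas_udf is called once per Arrow record batch, so np.mean(x) above is the mean of a single batch (10,000 rows by default), not of the whole 100,000-row column. If statistics over the full column are needed, compute them separately (e.g. with an aggregation or a window); the batch size itself is controlled by an Arrow setting, sketched below:
# Controls how many rows each Arrow batch (and thus each pandas_udf call)
# receives; the default is 10000
spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', 10000)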
Square the row_number column:
# Square each value of the batch (vectorized on the pandas Series)
@F.pandas_udf(FloatType())
def calc_square(a):
    return (a ** 2).astype('float32')

spark_df.withColumn('a4', calc_square('row_number')).show(5)
+----------+-----+
|row_number|   a4|
+----------+-----+
|         0|  0.0|
|         1|  1.0|
|         2|  4.0|
|         3|  9.0|
|         4| 16.0|
+----------+-----+
Another place where pandas_udf has a clear advantage over a plain udf is model inference. If you have trained, say, an XGBoost or LightGBM model, there is no need to rewrite the pipeline with Spark MLlib, which is cumbersome (features have to be indexed and assembled into Vector columns). Instead you can score the model directly inside a pandas_udf:
# Example: a trained model that predicts a user's gender from a few features.
# `model` is assumed to be a booster fitted on the driver (the predict call
# below follows the LightGBM API); Spark ships it to executors inside the
# udf closure.
@F.pandas_udf(returnType=FloatType())
def pred_gender_proba(*feature_list):
    # each feature arrives as a pandas Series; stack them into a 2-D array
    x_data = pd.concat(feature_list, axis=1).values
    y_pred = model.predict(x_data, num_iteration=model.best_iteration)
    return pd.Series(y_pred)

features = ['feat1', 'feat2', 'feat3']
df = df.withColumn('gender_prob', pred_gender_proba(*features))
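For a large model it is often better not to rely on closure pickling but to broadcast the model explicitly, so each executor fetches it once. A minimal sketch under the same assumptions (a fitted model object already loaded on the driver):
# Broadcast the fitted model instead of capturing it in the udf closure
bc_model = spark.sparkContext.broadcast(model)

@F.pandas_udf(returnType=FloatType())
def pred_gender_proba_bc(*feature_list):
    x_data = pd.concat(feature_list, axis=1).values
    y_pred = bc_model.value.predict(x_data)  # read the model from the broadcast
    return pd.Series(y_pred)

df = df.withColumn('gender_prob', pred_gender_proba_bc(*features))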