pyspark.sql.functions.pandas_udf(f=None, returnType=None, functionType=None)
Pandas UDFs are user-defined functions that Spark executes using Arrow to transfer data and pandas to work with the data, which allows vectorized operations.
In other words, Spark ships each batch of rows to Python over Arrow and hands it to the UDF as pandas objects, so the UDF body can use vectorized pandas operations instead of per-row Python calls.
Breakdown of the functionType parameter:

type | Description | Notes |
---|---|---|
SCALAR | Processes a column element-wise: takes one or more pandas Series as input and returns a pandas Series of the same length. Used with the DataFrame select and withColumn methods. Suited to element-wise operations. | default |
SCALAR_ITER | Like SCALAR, but operates on an iterator of batches, which lets expensive setup run once per partition and large datasets be processed more efficiently. | - |
GROUPED_MAP | For grouped transformations: the function receives each group as a pandas DataFrame and returns a pandas DataFrame matching the declared output schema (commonly the same shape as the input group). Used with groupBy and apply. | - |
GROUPED_AGG | For grouped aggregations: reduces the values of a group to a single scalar. Used with groupBy and agg (and with windows). | A notable difference from GROUPED_MAP: the inputs are individual columns (Series) rather than the whole group as a DataFrame, so the entire pdf cannot be passed into the UDF directly (see the struct workaround below). |
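Since Spark 3.0, the functionType is usually inferred from Python type hints rather than passed explicitly. A rough sketch of the correspondence (GROUPED_MAP has moved to groupby().applyInPandas, shown further below; the +1.0 / mean bodies are placeholders):

```python
from typing import Iterator
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def scalar_udf(s: pd.Series) -> pd.Series:                            # SCALAR
    return s + 1.0

@pandas_udf("double")
def scalar_iter_udf(it: Iterator[pd.Series]) -> Iterator[pd.Series]:  # SCALAR_ITER
    for s in it:
        yield s + 1.0

@pandas_udf("double")
def grouped_agg_udf(s: pd.Series) -> float:                           # GROUPED_AGG
    return float(s.mean())
```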
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
import numpy as np
import pandas as pd

spark = SparkSession.builder \
    .appName("pandas_udf scalar Example") \
    .getOrCreate()
# Create a dummy dataset
g = np.tile(['group a','group b'], 10)
x = np.linspace(0, 10., 20)
np.random.seed(3) # set seed for reproducibility
y_lin = 2*x + np.random.rand(len(x))/10.
y_qua = 3*x**2 + np.random.rand(len(x))
df = pd.DataFrame({'group': g, 'x': x, 'y_lin': y_lin, 'y_qua': y_qua})
schema = StructType([
StructField('group', StringType(), nullable=False),
StructField('x', DoubleType(), nullable=False),
StructField('y_lin', DoubleType(), nullable=False),
StructField('y_qua', DoubleType(), nullable=False),
])
df = spark.createDataFrame(df, schema=schema)
df.show(7)
+-------+------------------+-------------------+-------------------+
| group| x| y_lin| y_qua|
+-------+------------------+-------------------+-------------------+
|group a| 0.0|0.05507979025745755|0.28352508177131874|
|group b|0.5263157894736842| 1.123446361209179| 1.5241628490609185|
|group a|1.0526315789473684| 2.134353631786031| 3.7645534406624286|
|group b|1.5789473684210527| 3.2089774973618717| 7.6360921152062655|
|group a|2.1052631578947367| 4.299821011224239| 13.8410479099986|
|group b| 2.631578947368421| 5.352787203630186| 21.555938033209422|
|group a|3.1578947368421053| 6.328348004730595| 30.22326103930139|
+-------+------------------+-------------------+-------------------+
# Operate on a single column
# Series-to-Series pandas UDF (decorator style)
@F.pandas_udf(DoubleType())
def standardise(col1: pd.Series) -> pd.Series:
    return (col1 - col1.mean())/col1.std()
res = df.select(standardise(F.col('y_lin')).alias('result'))
res.show(5)
+-------------------+
| result|
+-------------------+
|-1.6054255151193093|
|-1.4337009540623533|
|-1.2712121491623172|
| -1.098481817986802|
|-0.9231444116198374|
+-------------------+
# The same UDF can also be registered without the decorator
def standardise(col1: pd.Series) -> pd.Series:
    return (col1 - col1.mean())/col1.std()
standard_udf = pandas_udf(standardise, DoubleType())
df = df.withColumn("y_lin_standard", standard_udf(F.col('y_lin')))
df.show(3)
+-------+------------------+-------------------+-------------------+-------------------+
| group| x| y_lin| y_qua| y_lin_standard|
+-------+------------------+-------------------+-------------------+-------------------+
|group a| 0.0|0.05507979025745755|0.28352508177131874|-1.6054255151193093|
|group b|0.5263157894736842| 1.123446361209179| 1.5241628490609185|-1.4337009540623533|
|group a|1.0526315789473684| 2.134353631786031| 3.7645534406624286|-1.2712121491623172|
+-------+------------------+-------------------+-------------------+-------------------+
# A scalar pandas UDF can also take multiple columns as input
def standardise(col1: pd.Series, col2: pd.Series) -> pd.Series:
    return (col1 - col2.mean())/col1.std()
standard_udf = pandas_udf(standardise, DoubleType())
df = df.withColumn("ret", standard_udf(F.col('y_lin'), F.col('y_qua')))
df.show(3)
+-------+------------------+-------------------+-------------------+-------------------+-------------------+
| group| x| y_lin| y_qua| y_lin_standard| ret|
+-------+------------------+-------------------+-------------------+-------------------+-------------------+
|group a| 0.0|0.05507979025745755|0.28352508177131874|-1.6054255151193093| -16.57141348616838|
|group b|0.5263157894736842| 1.123446361209179| 1.5241628490609185|-1.4337009540623533|-16.399688925111427|
|group a|1.0526315789473684| 2.134353631786031| 3.7645534406624286|-1.2712121491623172| -16.23720012021139|
+-------+------------------+-------------------+-------------------+-------------------+-------------------+
# Official example from the PySpark documentation: a struct column is passed to the UDF as a pd.DataFrame
@pandas_udf("col1 string, col2 long")
def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
    s3['col2'] = s1 + s2.str.len()
    return s3
# Create a Spark DataFrame that has three columns including a struct column.
df = spark.createDataFrame(
    [[1, "a string", ("a nested string",)]],
    "long_col long, string_col string, struct_col struct<col1:string>")
df.show()
+--------+----------+-----------------+
|long_col|string_col| struct_col|
+--------+----------+-----------------+
| 1| a string|{a nested string}|
+--------+----------+-----------------+
df.select(func("long_col", "string_col", "struct_col").alias('ret')).show()
+--------------------+
| ret|
+--------------------+
|{a nested string, 9}|
+--------------------+
# A pandas UDF can also return a pd.DataFrame, which maps to a struct column
@pandas_udf("first string, last string")
def split_expand(s: pd.Series) -> pd.DataFrame:
    return s.str.split(expand=True)
df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(split_expand("name")).show()
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import LongType
import pandas as pd
import re

spark = SparkSession.builder \
    .appName("Pandas UDF Example") \
    .getOrCreate()

def extract_numbers(series: pd.Series) -> pd.Series:
    # keep only the digits of each string; None when there are no digits
    return series.apply(lambda x: int(re.sub(r'\D', '', x)) if re.sub(r'\D', '', x) else None)

@pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
def extract_numbers_udf(series_iter):
    for series in series_iter:
        yield extract_numbers(series)
data = [("abc123",), ("def456",), ("ghi789",), ("jkl0",)]
schema = "text STRING"
input_df = spark.createDataFrame(data, schema=schema)
result_df = input_df.select(extract_numbers_udf("text").alias("numbers"))
result_df.show()
+-------+
|numbers|
+-------+
| 123|
| 456|
| 789|
| 0|
+-------+
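With Spark 3.0+ the same iterator UDF is usually written with type hints instead of PandasUDFType; a minimal sketch reusing extract_numbers and input_df from above:

```python
from typing import Iterator

# Iterator[pd.Series] -> Iterator[pd.Series] type hints are inferred as a SCALAR_ITER UDF
@pandas_udf("long")
def extract_numbers_iter(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for s in batches:
        yield extract_numbers(s)

input_df.select(extract_numbers_iter("text").alias("numbers")).show()
```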
# GROUPED_MAP: input and output are both a pandas.DataFrame
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())
df.groupby('id').apply(subtract_mean).show()
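PandasUDFType.GROUPED_MAP together with groupby().apply() is deprecated since Spark 3.0; the usual replacement is groupby().applyInPandas(), roughly:

```python
def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    # a plain Python function: each group arrives as a pandas DataFrame
    return pdf.assign(v=pdf.v - pdf.v.mean())

# the output schema is passed explicitly instead of through pandas_udf
df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
```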
# style 1: decorator with type hints (Series -> scalar is inferred as a grouped-agg UDF)
@F.pandas_udf(DoubleType())
def average_column(col1: pd.Series, col2: pd.Series) -> float:
    return (col1 + col2).mean()
res = df.groupby('group').agg(average_column(F.col('y_lin'), F.col('y_qua')).alias('average of y_lin + y_qua'))
# style 2: explicit registration with pandas_udf() and PandasUDFType.GROUPED_AGG
def average_(col1: pd.Series, col2: pd.Series) -> float:
    return (col1 + col2).mean()
average_column = pandas_udf(average_, DoubleType(), PandasUDFType.GROUPED_AGG)
res = df.groupby('group').agg(average_column(F.col('y_lin'), F.col('y_qua')).alias('average of y_lin + y_qua'))
res.show()
# Output (values rounded to three decimal places):
# +-------+------------------------+
# |group  |average of y_lin + y_qua|
# +-------+------------------------+
# |group a|104.770                 |
# |group b|121.621                 |
# +-------+------------------------+
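A grouped-aggregate pandas UDF can also be used as a window function. A minimal sketch reusing average_column and the group/y_lin/y_qua DataFrame from above; the unbounded window specification is an assumption for illustration:

```python
from pyspark.sql import Window

# one unbounded window per group, so every row receives its group's aggregate
w = (Window.partitionBy('group')
           .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

df.withColumn('group_avg', average_column(F.col('y_lin'), F.col('y_qua')).over(w)).show(5)
```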
Limitation 1: extra parameters cannot be passed to the UDF directly (it only receives columns); this can be worked around with a closure, i.e. a wrapper function that returns the actual UDF body, as shown below.
def sum_pd(pdf):
    v = pdf.v
    return pdf.assign(c=v.sum())
sum_udf = pandas_udf(sum_pd, "id long, v double, c double", PandasUDFType.GROUPED_MAP)
df.groupby("id").apply(sum_udf).show()
+---+----+----+
| id| v| c|
+---+----+----+
| 1| 1.0| 3.0|
| 1| 2.0| 3.0|
| 2| 3.0|18.0|
| 2| 5.0|18.0|
| 2|10.0|18.0|
+---+----+----+
# Example: passing an extra parameter to the UDF through a closure
def sum_pd(pp):
    def wrap(pdf):
        v = pdf.v
        return pdf.assign(c=v.sum() + pp)
    return wrap
pp = 1
sum_p = sum_pd(pp)
sum_udf = pandas_udf(sum_p, "id long, v double, c double", PandasUDFType.GROUPED_MAP)
df.groupby("id").apply(sum_udf).show()
+---+----+----+
| id| v| c|
+---+----+----+
| 1| 1.0| 4.0|
| 1| 2.0| 4.0|
| 2| 3.0|19.0|
| 2| 5.0|19.0|
| 2|10.0|19.0|
+---+----+----+
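The same closure trick works with the Spark 3 applyInPandas form; a sketch assuming the id/v DataFrame from above:

```python
def make_sum_pd(pp):
    def wrap(pdf: pd.DataFrame) -> pd.DataFrame:
        return pdf.assign(c=pdf.v.sum() + pp)
    return wrap

df.groupby("id").applyInPandas(make_sum_pd(1), schema="id long, v double, c double").show()
```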
Regarding limitation 4: whether multiple input columns are supported depends on the UDF's own signature. In the example I started with, the argument is a single pdf, so multiple columns cannot be passed directly; in that case you can bring in a StructType, bundling the needed columns into a struct that is passed to the UDF.
When the UDF is defined with multiple column arguments, multiple columns are supported directly, but for code simplicity I personally prefer the first approach.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType, struct
import pandas as pd

spark = SparkSession.builder \
    .appName("Two Day Subtract Example") \
    .getOrCreate()
data = [
(1, "2023-04-12", 2.0),
(1, "2023-04-11", 4.0),
(2, "2023-04-12", 3.0),
(2, "2023-04-11", 5.0)
]
columns = ["id", "dt", "v"]
df = spark.createDataFrame(data, columns)
def two_day_subtract(window_end, datapipe):
    def wrap(pdf):
        assert 0 < pdf.shape[0] <= 2
        dt = pdf[0][datapipe.dt]
        if pdf.shape[0] == 1:
            return pdf[0][datapipe.stat_col] * (1 if dt == window_end else -1)
        else:
            return (pdf[0][datapipe.stat_col] - pdf[1][datapipe.stat_col]) * (1 if dt == window_end else -1)
    return wrap
# simple container holding the column names used by the UDF
class dd:
    def __init__(self, stat_col, dt):
        self.stat_col = stat_col
        self.dt = dt
window_end = '2023-04-12'
datapipe = dd(stat_col='v', dt='dt')
idd = 'id'
subtract_udf = pandas_udf("double", PandasUDFType.GROUPED_AGG)(two_day_subtract(window_end, datapipe))
stat_df = df.groupby(idd).agg(subtract_udf(struct(df['v'], df['dt'])).alias("num"))
stat_df.show()
# The decorator style also works: apply the decorator to the inner function inside the closure
def two_day_subtract(window_end, datapipe):
    @pandas_udf("double", PandasUDFType.GROUPED_AGG)
    def wrap(pdf):
        assert 0 < pdf.shape[0] <= 2
        dt = pdf[0][datapipe.dt]
        if pdf.shape[0] == 1:
            return pdf[0][datapipe.stat_col] * (1 if dt == window_end else -1)
        else:
            return (pdf[0][datapipe.stat_col] - pdf[1][datapipe.stat_col]) * (1 if dt == window_end else -1)
    return wrap

subtract_udf = two_day_subtract(window_end, datapipe)
stat_df = df.groupby(idd).agg(subtract_udf(struct(df['v'], df['dt'])).alias("num"))
stat_df.show()
# Passing multiple columns directly
def two_day_subtract(window_end, datapipe):
    def wrap(s1, s2):
        assert 0 < len(s1) <= 2
        dt = s2[0]
        if len(s1) == 1:
            return s1[0] * (1 if dt == window_end else -1)
        else:
            return (s1[0] - s1[1]) * (1 if dt == window_end else -1)
    return wrap
subtract_udf = pandas_udf("double", PandasUDFType.GROUPED_AGG)(two_day_subtract(window_end, datapipe))
stat_df = df.groupby(idd).agg(subtract_udf(F.col('v'), F.col('dt')).alias("num"))
stat_df.show()
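With Spark 3.0+ type hints the same grouped aggregation can be written without PandasUDFType; a sketch under the same data and window_end as above (the factory name make_two_day_subtract is just illustrative):

```python
def make_two_day_subtract(window_end):
    # pd.Series -> float type hints are inferred as a grouped-agg UDF
    @pandas_udf("double")
    def wrap(v: pd.Series, dt: pd.Series) -> float:
        assert 0 < len(v) <= 2
        sign = 1 if dt.iloc[0] == window_end else -1
        if len(v) == 1:
            return float(v.iloc[0]) * sign
        return float(v.iloc[0] - v.iloc[1]) * sign
    return wrap

stat_df = df.groupby("id").agg(make_two_day_subtract(window_end)(F.col("v"), F.col("dt")).alias("num"))
stat_df.show()
```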
References:
1. PySpark documentation (pyspark.sql.functions.pandas_udf)
2. Pandas UDFs in PySpark
3. Databricks blog on pandas UDFs