PyFlink自定义函数

PyFlink(Apache Flink 的 Python API)中,自定义函数分为三种主要类型:ScalarFunction(标量函数)、TableFunction(表函数)和 AggregateFunction(聚合函数)。这些自定义函数可以在 Flink 的 SQL 和 Table API 中使用,用于扩展 PyFlink 的内置功能,处理自定义的计算逻辑。

1. 安装 PyFlink

在开始之前,确保你的环境中已安装了 PyFlink:

pip install apache-flink

2. 自定义标量函数 (ScalarFunction)

ScalarFunction 是最常见的自定义函数。它接受多个输入并返回一个标量值,类似于 SQL 中的普通函数。

2.1 创建一个标量函数

我们可以通过继承 ScalarFunction 类来定义一个自定义标量函数。例如,定义一个计算两个数之和的标量函数:

from pyflink.table.udf import ScalarFunction, udf
from pyflink.table import EnvironmentSettings, TableEnvironment

class AddScalarFunction(ScalarFunction):
    def eval(self, a, b):
        return a + b

# 创建 Table Environment
env_settings = EnvironmentSettings.in_streaming_mode()
table_env = TableEnvironment.create(env_settings)

# 注册并使用标量函数
add_func = udf(AddScalarFunction(), result_type='BIGINT')
table_env.create_temporary_system_function("add_func", add_func)

# 创建示例数据
table_env.execute_sql("""
    CREATE TEMPORARY VIEW input_table (a BIGINT, b BIGINT) AS 
    VALUES (1, 2), (3, 4), (5, 6)
""")

# 查询并使用自定义标量函数
result = table_env.sql_query("SELECT add_func(a, b) FROM input_table")
result.execute().print()
2.2 代码解析
  1. 创建标量函数: 继承 ScalarFunction 类,并实现 eval 方法。eval 方法接受多个参数并返回计算结果。
  2. 注册函数: 使用 udf() 包装自定义函数并注册到 TableEnvironment 中,以便在 SQL 查询中使用。
  3. 在 SQL 中使用自定义函数: 使用注册的函数名 add_func 在 SQL 中调用该自定义函数。

3. 自定义表函数 (TableFunction)

TableFunction 返回一个表,而不是单一值。它类似于 SQL 中的 LATERAL VIEWUNNEST 操作,允许将一行数据转换成多行输出。

3.1 创建一个表函数

我们可以通过继承 TableFunction 类来定义自定义表函数。例如,定义一个表函数,将字符串按逗号分割并返回多个值:

from pyflink.table.udf import TableFunction, udtf

class SplitTableFunction(TableFunction):
    def eval(self, text):
        for word in text.split(","):
            self.collect(word)

# 注册并使用表函数
split_func = udtf(SplitTableFunction(), result_types=['STRING'])
table_env.create_temporary_system_function("split_func", split_func)

# 创建示例数据
table_env.execute_sql("""
    CREATE TEMPORARY VIEW input_table (text STRING) AS 
    VALUES ('hello,world'), ('foo,bar,baz')
""")

# 查询并使用自定义表函数
result = table_env.sql_query("""
    SELECT text, word
    FROM input_table, LATERAL TABLE(split_func(text)) AS T(word)
""")
result.execute().print()
3.2 代码解析
  1. 创建表函数: 继承 TableFunction 类,并在 eval 方法中使用 self.collect() 收集每行数据的输出。
  2. 注册函数: 使用 udtf() 包装自定义表函数并注册到 TableEnvironment 中。
  3. 在 SQL 中使用表函数: 使用 LATERAL TABLE 在 SQL 查询中调用自定义表函数,从每行数据生成多个输出。

4. 自定义聚合函数 (AggregateFunction)

AggregateFunction 用于定义自定义的聚合逻辑,类似于 SQL 中的聚合函数(如 SUMCOUNT 等)。它接收多行输入并返回聚合结果。

4.1 创建一个聚合函数

我们可以通过继承 AggregateFunction 类来定义自定义聚合函数。例如,定义一个求平均值的聚合函数:

from pyflink.table.udf import AggregateFunction, udaf

class AvgAggregateFunction(AggregateFunction):

    class Accumulator:
        def __init__(self):
            self.sum = 0
            self.count = 0

    def get_value(self, accumulator):
        return accumulator.sum / accumulator.count if accumulator.count != 0 else 0

    def create_accumulator(self):
        return AvgAggregateFunction.Accumulator()

    def accumulate(self, accumulator, value):
        if value is not None:
            accumulator.sum += value
            accumulator.count += 1

# 注册并使用聚合函数
avg_func = udaf(AvgAggregateFunction(), result_type='DOUBLE', accumulator_type='ROW')
table_env.create_temporary_system_function("avg_func", avg_func)

# 创建示例数据
table_env.execute_sql("""
    CREATE TEMPORARY VIEW input_table (a BIGINT) AS 
    VALUES (1), (2), (3), (4), (5)
""")

# 查询并使用自定义聚合函数
result = table_env.sql_query("SELECT avg_func(a) FROM input_table")
result.execute().print()
4.2 代码解析
  1. 创建聚合函数: 继承 AggregateFunction 类。create_accumulator 用于创建累加器,accumulate 用于聚合数据,get_value 用于返回聚合结果。
  2. 定义累加器: 定义了一个 Accumulator 类来保存聚合的中间状态(例如总和和计数)。
  3. 注册函数: 使用 udaf() 包装自定义聚合函数并注册到 TableEnvironment 中。
  4. 在 SQL 中使用聚合函数: 在 SQL 查询中调用自定义聚合函数。

5. 完整示例

以下是一个完整的示例,展示了如何在一个 PyFlink 程序中定义并使用 ScalarFunctionTableFunctionAggregateFunction

from pyflink.table.udf import ScalarFunction, TableFunction, AggregateFunction, udf, udtf, udaf
from pyflink.table import EnvironmentSettings, TableEnvironment

# Scalar Function: 加法
class AddScalarFunction(ScalarFunction):
    def eval(self, a, b):
        return a + b

# Table Function: 按逗号分割字符串
class SplitTableFunction(TableFunction):
    def eval(self, text):
        for word in text.split(","):
            self.collect(word)

# Aggregate Function: 计算平均值
class AvgAggregateFunction(AggregateFunction):

    class Accumulator:
        def __init__(self):
            self.sum = 0
            self.count = 0

    def get_value(self, accumulator):
        return accumulator.sum / accumulator.count if accumulator.count != 0 else 0

    def create_accumulator(self):
        return AvgAggregateFunction.Accumulator()

    def accumulate(self, accumulator, value):
        if value is not None:
            accumulator.sum += value
            accumulator.count += 1

# 创建 Table Environment
env_settings = EnvironmentSettings.in_streaming_mode()
table_env = TableEnvironment.create(env_settings)

# 注册自定义函数
add_func = udf(AddScalarFunction(), result_type='BIGINT')
table_env.create_temporary_system_function("add_func", add_func)

split_func = udtf(SplitTableFunction(), result_types=['STRING'])
table_env.create_temporary_system_function("split_func", split_func)

avg_func = udaf(AvgAggregateFunction(), result_type='DOUBLE', accumulator_type='ROW')
table_env.create_temporary_system_function("avg_func", avg_func)

# 示例数据
table_env.execute_sql("""
    CREATE TEMPORARY VIEW input_table (a BIGINT, text STRING) AS 
    VALUES (1, 'foo,bar'), (2, 'hello,world'), (3, 'foo,baz')
""")

# 使用标量函数
result = table_env.sql_query("SELECT add_func(a, a) FROM input_table")
result.execute().print()

# 使用表函数
result = table_env.sql_query("""
    SELECT text, word
    FROM input_table, LATERAL TABLE(split_func(text)) AS T(word)
""")
result.execute().print()

# 使用聚合函数
result = table_env.sql_query("SELECT avg_func(a) FROM input_table")
result.execute().print()

6. 总结

在 PyFlink 中,自定义函数(ScalarFunctionTableFunctionAggregateFunction)是扩展 Flink SQL 和 Table API 功能的重要工具。通过编写自定义函数,你可以将复杂的业务逻辑集成到 Flink 的数据处理管道中,从而实现更灵活、更强大的数据处理应用。

你可能感兴趣的:(pyflink,flink)