该文章例子pyflink环境是apache-flink==1.13.6
Python 自定义函数是 PyFlink Table API 中最重要的功能之一,其允许用户在 PyFlink Table API 中使用 Python 语言开发的自定义函数,极大地拓宽了 Python Table API 的使用范围。
简单来说 就是有的业务逻辑和需求是sql语句满足不了或太麻烦的,需要用过函数来实现。
Python UDF,即 Python ScalarFunction,针对每一条输入数据,仅产生一条输出数据。
env_settings = EnvironmentSettings.new_instance().in_streaming_mode().build()
t_env = StreamTableEnvironment.create(environment_settings=env_settings)
table = t_env.from_elements([("hello&11", 1), ("world&22", 2), ("flink&33", 3)], ['a', 'b'])
#方式一:
#result_type 是输出类型,如果是多个返回值,则需写result_types
#同理也可以指定输入类型,input_type,多个返回值写input_types
@udf(result_type=DataTypes.STRING())
def sub_string(s: str, begin: int, end: int):
return s[begin:end]
@udf(result_type=DataTypes.STRING())
def split_t(s):
ab = s.split('&')
return ab[0]
result = table.select(split_t(table.a).alias('a'))
# 方式二:
sub_string_lambda_fun = udf(lambda s, begin, end: s[begin:end], result_type=DataTypes.STRING())
result = table.select(sub_string_lambda_fun(table.a, 1, 3))
方式三:
# 继承ScalarFunction
# 实现eval方法,来实现方法
# open在初始化时执行一次,跟java的富函数一样,比如需要全局执行一次的(mysql连接等),可以放在open方法中执行,
# 可以注册Metrics对象 https://nightlies.apache.org/flink/flink-docs-release-1.16/docs/dev/python/table/metrics/
class SubString(ScalarFunction):
def open(self, function_context):
#super().open(function_context)
#self.counter = function_context.get_metric_group().counter("my_counter")
pass
def eval(self, s: str, begin: int, end: int):
return s[begin:end]
sub_string = udf(SubString(), result_type=DataTypes.STRING())
result.execute().print()#直接打印
# result = result.to_pandas() ##这里可以转换成pandas
# 也可以用with遍历
+----+--------------------------------+
| op | a |
+----+--------------------------------+
| +I | hello |
| +I | world |
| +I | flink |
+----+--------------------------------+
Python UDTF,即 Python TableFunction,针对每一条输入数据,Python UDTF 可以产生 0 条、1 条或者多条输出数据,此外,一条输出数据可以包含多个列。比如以下示例,定义了一个名字为 split 的Python UDF,以指定字符串为分隔符,将输入字符串切分成两个字符串:
from pyflink.table.udf import udtf
from pyflink.table import DataTypes
env_settings = EnvironmentSettings.new_instance().in_streaming_mode().build()
t_env = StreamTableEnvironment.create(environment_settings=env_settings)
table = t_env.from_elements([("hello&11", 1), ("world&22", 2), ("flink&33", 3)], ['a', 'b'])
@udtf(result_types=[DataTypes.STRING(), DataTypes.STRING()])
def split(s: str, sep: str):
splits = s.split(sep)
yield splits[0], splits[1]
#合并两个结果集
#可以使用左、右和内等连接查询
#result = table.join_lateral(split(table.a, '&'))
result = table.left_outer_join_lateral(split(table.a, '&'))
result.execute().print()
+----+--------------------------------+----------------------+--------------------------------+--------------------------------+
| op | a | b | f0 | f1 |
+----+--------------------------------+----------------------+--------------------------------+--------------------------------+
| +I | hello&11 | 1 | hello | 11 |
| +I | world&22 | 2 | world | 22 |
| +I | flink&33 | 3 | flink | 33 |
+----+--------------------------------+----------------------+--------------------------------+--------------------------------+
Python UDAF,即 Python AggregateFunction。Python UDAF 用来针对一组数据进行聚合运算,比如同一个 window 下的多条数据、或者同一个 key 下的多条数据等。针对同一组输入数据,Python AggregateFunction 产生一条输出数据。比如以下示例,定义了一个名字为 weighted_avg 的 Python UDAF:
from pyflink.common import Row
from pyflink.table import AggregateFunction, DataTypes, EnvironmentSettings, StreamTableEnvironment
from pyflink.table.udf import udaf
class WeightedAvg(AggregateFunction):
## ImperativeAggregateFunction 类需要实现的抽象类
def create_accumulator(self):
print("111")
# Row(sum, count)
return Row(0, 0)
# AggregateFunction 类 需要实现的抽象类
def get_value(self, retract) -> float:
if retract[1] == 0:
return 0
else:
return retract[0] / retract[1]
## ImperativeAggregateFunction 类需要实现的抽象类
# 累加器方法
def accumulate(self, accumulator, value, weight):
print(value, weight)
accumulator[0] += value * weight
accumulator[1] += weight
# 缩减方法,这个不需要必须实现
def retract(self, accumulator: Row, value, weight):
accumulator[0] -= value * weight
accumulator[1] -= weight
weighted_avg = udaf(f=WeightedAvg(),
result_type=DataTypes.DOUBLE(),
accumulator_type=DataTypes.ROW([
DataTypes.FIELD("f0", DataTypes.BIGINT()),
DataTypes.FIELD("f1", DataTypes.BIGINT())]))
env_settings = EnvironmentSettings.new_instance().in_streaming_mode().use_blink_planner().build()
t_env = StreamTableEnvironment.create(environment_settings=env_settings)
t = t_env.from_elements([(1, 2, "Lee"), (3, 4, "Jay"), (5, 6, "Jay"), (7, 8, "Lee")],
["value", "count", "name"])
result = t.group_by(t.name).select(weighted_avg(t.value, t.count).alias("avg"))
result.execute().print()
+----+--------------------------------+
| op | avg |
+----+--------------------------------+
| +I | 5.8 |
| +I | 4.2 |
+----+--------------------------------+
Python UDTAF,即 Python TableAggregateFunction。Python UDTAF 用来针对一组数据进行聚合运算,比如同一个 window 下的多条数据、或者同一个 key 下的多条数据等,与 Python UDAF 不同的是,针对同一组输入数据,Python UDTAF 可以产生 0 条、1 条、甚至多条输出数据。
from pyflink.common import Row
from pyflink.table import DataTypes, EnvironmentSettings, StreamTableEnvironment
from pyflink.table.udf import udtaf, TableAggregateFunction
class Top2(TableAggregateFunction):
def create_accumulator(self):
# 存储当前最大的两个值
return [None, None]
def accumulate(self, accumulator, input_row):
if input_row[0] is not None:
# 新的输入值最大
if accumulator[0] is None or input_row[0] > accumulator[0]:
accumulator[1] = accumulator[0]
accumulator[0] = input_row[0]
# 新的输入值次大
elif accumulator[1] is None or input_row[0] > accumulator[1]:
accumulator[1] = input_row[0]
def emit_value(self, accumulator):
yield Row(accumulator[0])
if accumulator[1] is not None:
yield Row(accumulator[1])
top2 = udtaf(f=Top2(),
result_type=DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT())]),
accumulator_type=DataTypes.ARRAY(DataTypes.BIGINT()))
env_settings = EnvironmentSettings.new_instance().in_streaming_mode().use_blink_planner().build()
t_env = StreamTableEnvironment.create(environment_settings=env_settings)
t = t_env.from_elements([(1, 'Hi', 'Hello'),
(3, 'Hi', 'hi'),
(5, 'Hi2', 'hi'),
(2, 'Hi', 'Hello'),
(7, 'Hi', 'Hello')],
['a', 'b', 'c'])
t_env.execute_sql("""
CREATE TABLE my_sink (
word VARCHAR,
`sum` BIGINT
) WITH (
'connector' = 'print'
)
""")
result = t.group_by(t.b).flat_aggregate(top2).select("b, a").execute_insert("my_sink")
# 1)等待作业执行结束,用于local执行,否则可能作业尚未执行结束,该脚本已退出,会导致minicluster过早退出
# 2)当作业通过detach模式往remote集群提交时,比如YARN/Standalone/K8s等,需要移除该方法
result.wait()