Flink 使用之 SQL UDF

Flink 使用介绍相关文档目录

Flink 使用介绍相关文档目录

简介

在使用纯Flink SQL的场景下,对于复杂业务逻辑,Flink提供的内置fucntion是无法满足要求的。我们需要实现自定义的function,来扩充Flink的功能。用户自己实现的function称为UDF(user defined function)。

Flink支持如下四种UDF:

  • ScalarFunction: 类似于Flink算子的map,一对一转换。
  • TableFunction: 类似于flatmap,一对多。
  • AggregateFunction: 类似于reduce,多对一。通过聚合操作把多行输出为一个值。
  • TableAggregateFunction: 多对多。目前没发现如何在SQL中使用(官网给出了在Table API中的使用方法),暂不介绍。

编写注意事项

  • 编写UDF需要在项目中引入如下依赖。

    org.apache.flink
    flink-table-common
    ${flink.version}
    provided

  • UDF必须继承自ScalarFunction等基类。
  • UDF必须定义为public,不能为abstract。必须能被全局访问到。所以说不能包含非静态内部类或者匿名类。
  • 必须拥有默认构造函数(无参数构造函数)。使用Table API的时候可以支持使用有参数构造函数的UDF来构造有状态UDF。SQL模式建议使用无状态UDF。
  • UDF必须无状态,只能包含static字段和transient字段。

注册UDF

定义好的UDF在SQL使用之前,必须要注册。注册方法有如下两种。

使用Java API:

StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

// 注册UDF
// 创建UDF,无法覆盖已经存在的同名function。该function位于目前所在的catalog和database中(有命名空间)。全名为catalog_name.database_name.function_name
tEnv.createFunction("function_name", new MyFunction());
// 创建临时function,可以覆盖已存在的function,有命名空间
tEnv.createTemporaryFunction("function_name", new MyFunction());
// 创建临时系统function,可以覆盖已存在的function,位于全局,无命名空间概念
tEnv.createTemporarySystemFunction("function_name", new MyFunction());

使用SQL方式:

CREATE [TEMPORARY|TEMPORARY SYSTEM] FUNCTION
  [IF NOT EXISTS] [[catalog_name.]db_name.]function_name
  AS identifier [LANGUAGE JAVA|SCALA|PYTHON]

具体解释和Java API相同,不再赘述。

例如:

CREATE TEMPORARY SYSTEM FUNCTION changecase AS 'com.paultech.ChangeCaseTool';

注意:必须把UDF的jar包添加到Flink框架的classpath下(例如放置到$FLINK_HOME/lib中)。或者通过ADD JAR动态加载用户jar到classpath。参见https://nightlies.apache.org/flink/flink-docs-release-1.15/zh/docs/dev/table/sql/jar/。

结果计算

UDF可以按照实际需要,重写基类提供的open()close()isDeterministic()方法。

UDF的结果计算方法例如eval(), accumulate(), 或者 retract()方法,在运行阶段被动态生成的代码调用。

结果计算方法可以定义一个或者多个参数,可以使用重载方法,也可以使用变长参数。

类型推断

Flink Table API是强类型API,所有函数的参数类型和返回类型都必须映射为DataType。Flink支持自动类型推断和通过注解(@DataTypeHint@FunctionHint)指定类型。如果有更为复杂的类型推断逻辑,可以重写父类的getTypeInference方法。

自动类型推断

对于自动类型推断,Java数据类型和DataType类型对应关系参见https://nightlies.apache.org/flink/flink-docs-release-1.15/docs/dev/table/types/#data-type-extraction。

注解显式指定类型

@DataTypeHint可用于返回值,方法体(作用于返回值)和方法参数上,从而修改返回值或者式参数的DataType。

@DataTypeHint支持复杂类型,例如@DataTypeHint("ROW")

@FunctionHint适用于一个eval等结果计算方法可以接收多组类型不同的参数,返回值类型和接收参数类型相关的这种场景。我们贴出官网的例子:

import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.FunctionHint;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.types.Row;

// function with overloaded evaluation methods
// but globally defined output type
@FunctionHint(output = @DataTypeHint("ROW"))
public static class OverloadedFunction extends TableFunction {

  public void eval(int a, int b) {
    collect(Row.of("Sum", a + b));
  }

  // overloading of arguments is still possible
  public void eval() {
    collect(Row.of("Empty args", -1));
  }
}

// decouples the type inference from evaluation methods,
// the type inference is entirely determined by the function hints
@FunctionHint(
  input = {@DataTypeHint("INT"), @DataTypeHint("INT")},
  output = @DataTypeHint("INT")
)
@FunctionHint(
  input = {@DataTypeHint("BIGINT"), @DataTypeHint("BIGINT")},
  output = @DataTypeHint("BIGINT")
)
@FunctionHint(
  input = {},
  output = @DataTypeHint("BOOLEAN")
)
public static class OverloadedFunction extends TableFunction {

  // an implementer just needs to make sure that a method exists
  // that can be called by the JVM
  public void eval(Object... o) {
    if (o.length == 0) {
      collect(false);
    }
    collect(o[0]);
  }
}

自定义类型推断

如果注解无法描述类型推断逻辑,可以重写getTypeInference方法,使用代码实现复杂的类型推断逻辑。写法和参考官网的例子。

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.types.Row;

public static class LiteralFunction extends ScalarFunction {
  public Object eval(String s, String type) {
    switch (type) {
      case "INT":
        return Integer.valueOf(s);
      case "DOUBLE":
        return Double.valueOf(s);
      case "STRING":
      default:
        return s;
    }
  }

  // the automatic, reflection-based type inference is disabled and
  // replaced by the following logic
  @Override
  public TypeInference getTypeInference(DataTypeFactory typeFactory) {
    return TypeInference.newBuilder()
      // specify typed arguments
      // parameters will be casted implicitly to those types if necessary
      .typedArguments(DataTypes.STRING(), DataTypes.STRING())
      // specify a strategy for the result data type of the function
      .outputTypeStrategy(callContext -> {
        if (!callContext.isArgumentLiteral(1) || callContext.isArgumentNull(1)) {
          throw callContext.newValidationError("Literal expected for second argument.");
        }
        // return a data type based on a literal
        final String literal = callContext.getArgumentValue(1, String.class).orElse("STRING");
        switch (literal) {
          case "INT":
            return Optional.of(DataTypes.INT().notNull());
          case "DOUBLE":
            return Optional.of(DataTypes.DOUBLE().notNull());
          case "STRING":
          default:
            return Optional.of(DataTypes.STRING());
        }
      })
      .build();
  }
}

确定性

如果UDF不能返回确定的结果(例如random(), date()now()),必须重写isDeterministic()并返回false。这涉及到执行计划优化过程。

如果UDF的isDeterministic()返回true,并且传入的参数全都是常量,在planning阶段该UDF的值会被预先计算出来。例如SELECT ABS(-1)会优化为SELECT 1。但是SELECT ABS(field) FROM t不会优化,因为field不是常量。

如果UDF的isDeterministic()返回false,或者传入的参数存在变量,UDF的值在执行阶段才会被计算出来。

open和close方法

openclose方法可用于编写自定义的初始化和清理逻辑。open方法的执行时机早于eval等结果计算方法。

可参考官网的例子(https://nightlies.apache.org/flink/flink-docs-release-1.15/docs/dev/table/functions/udfs/#runtime-integration)。这个例子在启动作业的时候加载job parameter。

import org.apache.flink.table.api.*;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.ScalarFunction;

public static class HashCodeFunction extends ScalarFunction {

    private int factor = 0;

    @Override
    public void open(FunctionContext context) throws Exception {
        // access the global "hashcode_factor" parameter
        // "12" would be the default value if the parameter does not exist
        factor = Integer.parseInt(context.getJobParameter("hashcode_factor", "12"));
    }

    public int eval(String s) {
        return s.hashCode() * factor;
    }
}

TableEnvironment env = TableEnvironment.create(...);

// add job parameter
env.getConfig().addJobParameter("hashcode_factor", "31");

// register the function
env.createTemporarySystemFunction("hashCode", HashCodeFunction.class);

// use the function
env.sqlQuery("SELECT myField, hashCode(myField) FROM MyTable");

ScalarFunction

直接用例子说明用法。我们编写一个大小写转换UDF。可以接收多个参数。默认将string转换为大写。或者是通过boolean指定转换为大写还是小写。

public static class ChangeCaseTool extends ScalarFunction {
    public String eval(String s) {
        return s.toUpperCase(Locale.ROOT);
    }
    public String eval(String s, Boolean changeToUppercase) {
        if (changeToUppercase) {
            return s.toUpperCase(Locale.ROOT);
        } else {
            return s.toLowerCase(Locale.ROOT);
        }
    }
}

配合如下例子使用:

StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

// 测试表数据
DataStreamSource streamSource = env.fromElements(
        Row.of("hello#world", 100),
        Row.of("hola#hastaLAvista", 50));

// 转换DataStream为Table,并指定字段名
Table table = tEnv.fromDataStream(streamSource).as("name", "value");

// 将Table映射为demo表
tEnv.createTemporaryView("demo", table);

// 注册UDF
tEnv.createTemporaryFunction("changecase", new ChangeCaseTool());

// 执行SQL时候调用UDF
tEnv.executeSql("select changecase(`name`, true) as `name`, `value` from demo").print();

输出如下:

+----+--------------------------------+-------------+
| op |                           name |       value |
+----+--------------------------------+-------------+
| +I |                    HELLO#WORLD |         100 |
| +I |              HOLA#HASTALAVISTA |          50 |
+----+--------------------------------+-------------+

TableFunction

TableFunction将一个字段拆分为多列。

同样直接以例子说明。例子中的方法将内容按照指定的delimiter拆分,然后获取拆分后第一个和第二个字符串,列明分别为word1和word2。

@FunctionHint(output = @DataTypeHint("ROW"))
public static class StringSplitter extends TableFunction {
    public void eval(String s, String delimiter) {
        String[] split = s.split(delimiter);
        if (split.length >= 2) {
            collect(Row.of(split[0], split[1]));
        } else if (split.length == 1) {
            collect(Row.of(split[0], null));
        }
    }
}

由于TableFunction的计算结果是一个伪表,我们对它进行操作的时候(例如join)需要使用LATERAL TABLE(function(field))或者LATERAL TABLE(function(field)) AS T(NEW_FIELD_NAME1, NEW_FIELD_NAME2)(修改字段名)把UDF计算结果作为表来使用。

例子如下:

tEnv.createTemporaryFunction("split", new StringSplitter());
tEnv.executeSql("select * from demo, lateral table(split(`name`, '#'))").print();

结果如下:

+----+--------------------------------+-------------+--------------------------------+--------------------------------+
| op |                           name |       value |                          word1 |                          word2 |
+----+--------------------------------+-------------+--------------------------------+--------------------------------+
| +I |                    hello#world |         100 |                          hello |                          world |
| +I |              hola#hastaLAvista |          50 |                           hola |                   hastaLAvista |
+----+--------------------------------+-------------+--------------------------------+--------------------------------+

AggregateFunction

用一个例子说明。编写一个自定义聚合函数MyAvg,根据物品单价和数量,求单价的平均值。UDF代码如下:

// 自定义聚合器,持有总价和数量,以便于计算平均值
public static class MyAvgAggregator {
    public double sum;
    public int count;
}

// AggregateFunction需要声明聚合结果数据类型和自定义聚合器类型
public static class MyAvg extends AggregateFunction {

    // 获取计算结果的方法
    @Override
    public Double getValue(MyAvgAggregator accumulator) {
        return accumulator.sum / accumulator.count;
    }

    // 创建自定义聚合器
    @Override
    public MyAvgAggregator createAccumulator() {
        return new MyAvgAggregator();
    }

    // 聚合方法(必须),将数据加入到聚合器
    public void accumulate(MyAvgAggregator acc, Double unit, Integer count) {
        acc.sum += unit * count;
        acc.count += count;
    }

    // 撤回方法(可选),假设数据已经添加进自定义聚合器。该方法指定了将数据从自定义聚合器减去的逻辑。
    // 对于unbounded tables进行bounded OVER 聚合运算,必须提供此方法(需要减去over window旧的聚合数据,添加新的数据后重新计算聚合结果)
    public void retract(MyAvgAggregator acc, Double unit, Integer count) {
        acc.sum -= unit * count;
        acc.count -= count;
    }

    // 合并方法(可选),包含合并多个自定义聚合器的逻辑
    // 对于unbounded session window grouping聚合和bounded grouping聚合,必须提供此方法
    public void merge(MyAvgAggregator acc, Iterable it) {
        for (MyAvgAggregator a : it) {
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }
}

例子如下:

StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setRuntimeMode(RuntimeExecutionMode.BATCH);
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

DataStreamSource streamSource = env.fromElements(
        Row.of("Apple", 4.0, 50),
        Row.of("Banana", 7.5, 20),
        Row.of("Peach", 8.0, 15)
);

Table table = tEnv.fromDataStream(streamSource).as("name", "unit_price", "count");

tEnv.createTemporaryView("demo", table);

tEnv.createTemporaryFunction("myavg", new MyAvg());

tEnv.executeSql("select myavg(`unit_price`, `count`) as avg_unit_price from demo").print();

执行结果:

+----+--------------------------------+
| op |                 avg_unit_price |
+----+--------------------------------+
| +I |              5.529411764705882 |
+----+--------------------------------+

当然,对于金额运算结果,我们可以让MyAvg返回BigDecimal类型(其实可以使用ROUND函数解决。这里我们演示下上面介绍的@FunctionHint注解用法)。我们改写MyAvg如下:

// 增加注解,声明输入和输出的数据类型
@FunctionHint(input = {@DataTypeHint("DOUBLE"), @DataTypeHint("INT")}, output = @DataTypeHint("DECIMAL(12, 2)"))
public static class MyAvg extends AggregateFunction {
    @Override
    public BigDecimal getValue(MyAvgAggregator accumulator) {
        return BigDecimal.valueOf(accumulator.sum).divide(BigDecimal.valueOf(accumulator.count), 2, RoundingMode.HALF_DOWN);
    }

    // 其余方法完全相同,此处省略
    // ...
}

查询SQL修改为:

tEnv.executeSql("select cast(myavg(`unit_price`, `count`) as DECIMAL(12, 2)) as avg_unit_price from demo").print();

结果如下:

+----+----------------+
| op | avg_unit_price |
+----+----------------+
| +I |           5.53 |
+----+----------------+

参考文档

https://nightlies.apache.org/flink/flink-docs-release-1.15/docs/dev/table/functions/udfs/#user-defined-functions

https://nightlies.apache.org/flink/flink-docs-release-1.15/docs/dev/table/types/

你可能感兴趣的:(Flink 使用之 SQL UDF)