作者:徐嘉 StarRocks Active Contributor
StarRocks 函数就像预设于数据库中的公式,允许用户调用现有的函数以完成特定功能。函数可以很方便地实现业务逻辑的重用,因此正确使用函数会让读者在编写 SQL 语句时起到事半功倍的效果。
StarRocks 提供了多种内置函数,包括标量函数、聚合函数、窗口函数、Table 函数和 Lambda 函数等,可帮助用户更加便捷地处理表中的数据。此外,StarRocks 还允许用户自定义函数以适应实际的业务操作。本文将以标量函数和聚合函数为例,介绍 StarRocks 常见的两种函数实现原理,希望读者能够借鉴其设计思路,并按需实现所需的函数。同时,我们也欢迎社区小伙伴一起贡献力量,共同完善 StarRocks 的功能,具体的函数任务认领方式请见文末。
标量函数用于处理单行数据,接受一个或多个参数作为输入,并返回一个值作为结果。StarRocks 常见的标量函数有 abs、floor、ceil 等。
首先,我们来了解函数签名,函数签名用来唯一标识函数,描述函数的 ID、名字、返回类型、输入参数的类型等基本信息。标量函数的函数签名定义在gensrc/script/f
unctions.py
,在编译阶段我们会根据 Python 文件中的内容生成对应的 Java 和 C++ 代码,供 FE 和 BE 使用。
每个函数签名在 Python 文件中通过一个特定的数组来描述,数组的内容有如下两种格式:
[, , , [...], ] or [, , , [...], , , ]
其中基本信息如下:
function_id:函数唯一标识,是唯一一串数字,function_id 遵循如下约定,前两位表示 function_type,中间两位表示 function_group,余下的表示具体的 sub_function,后面我们会举例说明
function_name:函数名称
return_type:返回值类型
arg_type:入参类型,如果有多个入参,需要在数组中描述每个入参的类型
be_scalar_function:BE 中负责实现该函数计算逻辑的函数
be_prepare_function/be_close_function:可选参数,有些函数在执行的过程中可能会传递一些状态,be_prepare_function 和 be_close_function 就是 BE 中负责实现创建状态和回收状态的函数
为了支持多种数据类型作为输入,需要为每种类型单独创建函数签名。以下以 abs 函数为例,该函数用于计算绝对值,需要描述以下五个信息:
function_id:1
0
代表它们都属于 math function,04
代表它们都属于 abs 这个 function group,余下的数字用来区分具体的 sub-function
function_name:函数名称都是 abs
return_type:返回值类型,同入参类型一致
arg_type:该函数只接受一个入参,所以第四项的数组中只有一个元素。
be_eval_function:BE 中实现计算逻辑的函数,StarRocks 针对每种数据类型做了特殊处理,所以每个签名中的函数名也不一样
对于 abs 函数而言,由于不需要传递状态,因此不需要 be_prepare_function 和 be_close_function 这两个选项。请注意,这两个选项在某些情况下可能会用到,具体用法将在后面的示例中介绍。
[10040, "abs", "DOUBLE", ["DOUBLE"], "MathFunctions::abs_double"],
[10041, "abs", "FLOAT", ["FLOAT"], "MathFunctions::abs_float"],
[10042, "abs", "LARGEINT", ["LARGEINT"], "MathFunctions::abs_largeint"],
[10043, "abs", "LARGEINT", ["BIGINT"], "MathFunctions::abs_bigint"],
[10044, "abs", "BIGINT", ["INT"], "MathFunctions::abs_int"],
[10045, "abs", "INT", ["SMALLINT"], "MathFunctions::abs_smallint"],
[10046, "abs", "SMALLINT", ["TINYINT"], "MathFunctions::abs_tinyint"],
[10047, "abs", "DECIMALV2", ["DECIMALV2"], "MathFunctions::abs_decimalv2val"],
[100470, "abs", "DECIMAL32", ["DECIMAL32"], "MathFunctions::abs_decimal32"],
[100471, "abs", "DECIMAL64", ["DECIMAL64"], "MathFunctions::abs_decimal64"],
[100472, "abs", "DECIMAL128", ["DECIMAL128"], "MathFunctions::abs_decimal128"],
在 StarRocks 的编译和执行阶段,都会使用函数签名来确定函数的输入输出和执行逻辑。具体流程如下:
在编译阶段,根据gensrc/script/functions.py
中的内容生成代码供FE和BE使用。
Java 代码在fe/fe-core/target/generated-sources/build/com/starrocks/builtins/VectorizedBuiltinFunctions.java
,FunctionSet[1] 保存了所有的函数签名,初始化阶段会调用VectorizedBuiltinFunctions::initBuiltins
来添加标量函数的函数签名。SQL analyze 阶段,会利用 FunctionSet 提供的信息进行校验,如果找不到函数签名会直接返回错误,这部分实现在 ExpressionAnalyzer.Visitor [2]的 visitFunctionCall[3] 方法中。
C++ 代码在./gensrc/build/gen_C++/opcode/builtin_function
s.cpp
,BE 标量函数的函数签名保存在BuiltinFunctions::_fn_tables
[4],生成的代码用于初始化_fn_tables。在 SQL 执行阶段,VectorizedFunctionCallExpr 会根据 fid(函数唯一标识)从 _fn_tables 中找到执行该函数所需要的信息,包括输入参数的个数,执行函数的函数指针(ScalarFunction),以及执行前后的 PrepareFunction 和 CloseFunction,这部分定义在 FunctionDescriptor[5]。
这部分此处不做赘述,根据函数的功能实现相关的逻辑即可。
接下来我们以 sha2 函数为例,介绍引入新函数的具体流程。sha2 函数的功能如下图,其详细信息可以参考官方文档[6]中的介绍。
首先,需要在gensrc/script/functions.py
中新增签名。
[120160, "sha2", "VARCHAR", ["VARCHAR", "INT"], "EncryptionFunctions::sha2", "EncryptionFunctions::sha2_prepare", "EncryptionFunctions::sha2_close"],
如上述代码所示,sha2 函数输入需要两个参数,根据第二个参数来决定使用哪种加密算法,如果第二个参数本身是个常数,那么不需要每次执行的时候都去判断。我们可以把这部分“状态”保存起来,所以函数签名中除了前文所述的五个基本信息之外,还增加了 EncryptionFunctions::sha2_prepare
和 EncryptionFunctions::sha2_close
,用来实现状态的创建和回收。
sha2 属于加密函数的一种,所以我们直接在 EncryptionFunctions [7]中增加相应的方法即可。具体代码如下:
/*
* Called by sha2 to the corresponding part
*/
DEFINE_VECTORIZED_FN(sha224);
DEFINE_VECTORIZED_FN(sha256);
DEFINE_VECTORIZED_FN(sha384);
DEFINE_VECTORIZED_FN(sha512);
DEFINE_VECTORIZED_FN(invalid_sha);
/**
* @param: [json_string, tagged_value]
* @paramType: [BinaryColumn, BinaryColumn]
* @return: Int32Column
*/
DEFINE_VECTORIZED_FN(sha2);
static Status sha2_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope);
static Status sha2_close(FunctionContext* context, FunctionContext::FunctionStateScope scope);
其中,实现标量函数的计算逻辑主要分布在 PrepareFuntion、ScalarFunction、CloseFunction 三个函数中。
Prepare 阶段主要是针对第二个参数进行特殊处理,如果是常数,可以把实现对应加密算法的函数指针保存起来,后面的 ScalarFunction 中可以直接调用。加密算法的函数指针保存在 EncryptionFunctions::SHA2Ctx
中,通过 FunctionContext::set_function_state
保存在上下文中。具体代码如下:
Status EncryptionFunctions::sha2_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
if (scope != FunctionContext::FRAGMENT_LOCAL) {
return Status::OK();
}
if (!context->is_notnull_constant_column(1)) {
return Status::OK();
}
ColumnPtr column = context->get_constant_column(1);
auto hash_length = ColumnHelper::get_const_value(column);
ScalarFunction function;
if (hash_length == 224) {
function = &EncryptionFunctions::sha224;
} else if (hash_length == 256 || hash_length == 0) {
function = &EncryptionFunctions::sha256;
} else if (hash_length == 384) {
function = &EncryptionFunctions::sha384;
} else if (hash_length == 512) {
function = &EncryptionFunctions::sha512;
} else {
function = EncryptionFunctions::invalid_sha;
}
auto fc = new EncryptionFunctions::SHA2Ctx();
fc->function = function;
context->set_function_state(scope, fc);
return Status::OK();
}
ScalarFunction 主要实现 sha2 的计算逻辑,如果第二个参数是常数,那么 PrepareFunction 中保存的 function_state
就可以派上用场了。具体代码如下:
StatusOr EncryptionFunctions::sha2(FunctionContext* ctx, const Columns& columns) {
if (!ctx->is_notnull_constant_column(1)) {
auto src_viewer = ColumnViewer(columns[0]);
auto length_viewer = ColumnViewer(columns[1]);
auto size = columns[0]->size();
ColumnBuilder result(size);
for (int row = 0; row < size; row++) {
if (src_viewer.is_null(row) || length_viewer.is_null(row)) {
result.append_null();
continue;
}
auto src_value = src_viewer.value(row);
auto length = length_viewer.value(row);
if (length == 224) {
SHA224Digest digest;
digest.update(src_value.data, src_value.size);
digest.digest();
result.append(Slice(digest.hex().c_str(), digest.hex().size()));
} else if (length == 0 || length == 256) {
SHA256Digest digest;
digest.update(src_value.data, src_value.size);
digest.digest();
result.append(Slice(digest.hex().c_str(), digest.hex().size()));
} else if (length == 384) {
SHA384Digest digest;
digest.update(src_value.data, src_value.size);
digest.digest();
result.append(Slice(digest.hex().c_str(), digest.hex().size()));
} else if (length == 512) {
SHA512Digest digest;
digest.update(src_value.data, src_value.size);
digest.digest();
result.append(Slice(digest.hex().c_str(), digest.hex().size()));
} else {
result.append_null();
}
}
return result.build(ColumnHelper::is_all_const(columns));
}
auto ctc = reinterpret_cast(ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
return ctc->function(ctx, columns);
}
CloseFunction 主要用来回收资源。函数执行中所依赖的 function state,在执行结束之后不再被需要,那么可以在这个阶段释放内存。具体代码如下:
Status EncryptionFunctions::sha2_close(FunctionContext* context, FunctionContext::FunctionStateScope scope) { if (scope == FunctionContext::FRAGMENT_LOCAL) { auto fc = reinterpret_cast
具体细节可参考
EntryptionFunctionTest
[8] 即可。代码示例如下:
TEST_P(ShaTestFixture, test_sha2) {
auto [str, len, expected] = GetParam();
std::unique_ptr ctx(FunctionContext::create_test_context());
Columns columns;
auto plain = BinaryColumn::create();
plain->append(str);
ColumnPtr hash_length =
len == -1 ? ColumnHelper::create_const_null_column(1) : ColumnHelper::create_const_column(len, 1);
if (str == "NULL") {
columns.emplace_back(ColumnHelper::create_const_null_column(1));
} else {
columns.emplace_back(plain);
}
columns.emplace_back(hash_length);
ctx->set_constant_columns(columns);
ASSERT_TRUE(EncryptionFunctions::sha2_prepare(ctx.get(), FunctionContext::FunctionStateScope::FRAGMENT_LOCAL).ok());
if (len != -1) {
ASSERT_NE(nullptr, ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
} else {
ASSERT_EQ(nullptr, ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
}
ColumnPtr result = EncryptionFunctions::sha2(ctx.get(), columns).value();
if (expected == "NULL") {
std::cerr << result->debug_string() << std::endl;
EXPECT_TRUE(result->is_null(0));
} else {
auto v = ColumnHelper::cast_to(result);
EXPECT_EQ(expected, v->get_data()[0].to_string());
}
ASSERT_TRUE(EncryptionFunctions::sha2_close(ctx.get(),
FunctionContext::FunctionContext::FunctionStateScope::FRAGMENT_LOCAL)
.ok());
}
完整的改动可以参考 PR:https://github.com/StarRocks/starrocks/pull/1264/files。
聚合函数用于处理多行数据,接受多行数据作为输入,经过计算后返回一行结果。StarRocks 常见的聚合函数有 count、sum、avg、min、max 等。
在查询执行阶段,Pipeline 引擎的聚合算子通过 Aggregator 完成聚合计算,聚合算子的实现原理可参见文末《StarRocks 聚合算子源码解析》[9],本文主要关注聚合函数的实现原理。
Aggregator 在 prepare 阶段会根据函数名找到对应的 AggregateFunction 并保存下来,AggregateFunction 是最重要的抽象,封装了聚合计算过程中需要的各个接口,每个聚合函数都需要继承 AggregateFunction 实现自己的逻辑。计算的中间结果保存在 AggDataPtr 中,AggDataPrt 是一个指针,指向描述中间结果的数据结构。每种聚合函数的中间结果都不相同,比如求和函数,只需要保存 sum 即可,而平均值函数,除了保存 sum 之外,还需要记录 count。
在 AggregateFunction 提供的接口中,我们需要重点关注以下几个:
// 逐行读取数据,不断更新 state 中保存的中间结果。
void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state, size_t row_num)
// 通常用在多阶段聚合中,读取已经算好的部分中间结果,合并计算,更新 state 中的数据。
void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num)
// 多阶段的聚合可能会通过多个节点执行,计算的中间结果需要跨网络传输,这个方法用来实现序列化的逻辑。
void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to)
// 把中间结果转成最终对用户返回的结果。比如求和函数,直接返回中间结果保存的 sum 即可,而平均值函数,需要返回 sum/count。
void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to)
// 重置 state 的状态,比如在 window aggregate 中,我们会用一个的 state 保存中间结果,每次遇到新的 group时,需要通过 reset 重置,然后才能进行接下来的计算。
void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state)
除了上述内容之外,为了减少函数调用的开销,AggregateFunction 还封装了批量操作的接口,具体的细节这里就不展开讲解了,可以参考be/src/exprs/agg/aggregate.h
。
接下来我们以 ANY_VALUE 为例,介绍添加聚合函数的流程,这个函数实现的功能比较简单,可以参考官方文档[10]说明:
FE 通过 AggregateFunction[11] 来描述聚合函数,所有的聚合函数都会注册在 FunctionSet 中,初始化阶段在FunctionSet的initAggregateBuiltins [12]方法内增加对应的函数即可。具体代码如下:
// ANY_VALUE
addBuiltin(AggregateFunction.createBuiltin(ANY_VALUE,
Lists.newArrayList(t), t, t, true, false, false));
此处重点是如何描述中间结果,以及如何实现 AggregateFunction 的核心接口。
ANY_VALUE
的语义很简单,在每个 group 中选择一行返回。中间结果通过 AnyValueAggregateData
描述,只需要记录当前是否已经有结果以及对应的数据是什么即可,AnyValueAggregateData
为每种数据类型进行了特化,实现上几乎一致。具体代码如下:
template
struct AnyValueAggregateData {
using T = AggDataValueType;
T result;
bool has_value = false;
void reset() {
result = T{};
has_value = false;
}
};
具体的计算逻辑非常简单,这部分通过 AnyValueElement 实现。具体代码如下:
template
struct AnyValueElement {
using RefType = AggDataRefType;
void operator()(State& state, RefType right) const {
if (UNLIKELY(!state.has_value)) {
AggDataTypeTraits::assign_value(state.result, right);
state.has_value = true;
}
}
};
最后利用 AnyValueElement 实现 AggregateFunction 所需要的接口即可,具体代码如下:
template , typename = guard::Guard>
class AnyValueAggregateFunction final
: public AggregateFunctionBatchHelper> {
public:
using InputColumnType = RunTimeColumnType;
void reset(FunctionContext* ctx, const Columns& args, AggDataPtr state) const override {
this->data(state).reset();
}
void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
size_t row_num) const override {
DCHECK(!columns[0]->is_nullable());
const auto& column = down_cast(*columns[0]);
OP()(this->data(state), AggDataTypeTraits::get_row_ref(column, row_num));
}
void update_batch_single_state(FunctionContext* ctx, size_t chunk_size, const Column** columns,
AggDataPtr __restrict state) const override {
update(ctx, columns, state, 0);
}
void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {
DCHECK(!column->is_nullable());
const auto& input_column = down_cast(*column);
OP()(this->data(state), AggDataTypeTraits::get_row_ref(input_column, row_num));
}
void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
DCHECK(!to->is_nullable());
AggDataTypeTraits::append_value(down_cast(to), this->data(state).result);
}
void convert_to_serialize_format(FunctionContext* ctx, const Columns& src, size_t chunk_size,
ColumnPtr* dst) const override {
*dst = src[0];
}
void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
DCHECK(!to->is_nullable());
AggDataTypeTraits::append_value(down_cast(to), this->data(state).result);
}
void get_values(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* dst, size_t start,
size_t end) const override {
DCHECK_GT(end, start);
InputColumnType* column = down_cast(dst);
for (size_t i = start; i < end; ++i) {
AggDataTypeTraits::append_value(column, this->data(state).result);
}
}
std::string get_name() const override { return "any_value"; }
};
完整的实现细节参见:be/src/exprs/agg/any_value.h
这一步是为了让 AggregateFactory 可以根据函数名找到对应的函数,函数的创建通过MakeAnyValueAggregateFunction
实现,相关的改动可以在 aggregate_factory.hpp
[13] 中 grep MakeAnyValueAggregateFunction
看到,比较简单,这里不再过多赘述,具体示例如下:
template
AggregateFunctionPtr AggregateFactory::MakeAnyValueAggregateFunction() {
return std::make_shared<
AnyValueAggregateFunction, AnyValueElement>>>();
}
可以参见 test/exprs/agg/aggregate_test.
cpp
[14]添加单测,比如:
TEST_F(AggregateTest, test_any_value) {
const AggregateFunction* func = get_aggregate_function("any_value", TYPE_SMALLINT, TYPE_SMALLINT, false);
test_non_deterministic_agg_function(ctx, func);
func = get_aggregate_function("any_value", TYPE_INT, TYPE_INT, false);
test_non_deterministic_agg_function(ctx, func);
func = get_aggregate_function("any_value", TYPE_BIGINT, TYPE_BIGINT, false);
test_non_deterministic_agg_function(ctx, func);
func = get_aggregate_function("any_value", TYPE_LARGEINT, TYPE_LARGEINT, false);
test_non_deterministic_agg_function(ctx, func);
func = get_aggregate_function("any_value", TYPE_FLOAT, TYPE_FLOAT, false);
test_non_deterministic_agg_function(ctx, func);
func = get_aggregate_function("any_value", TYPE_DOUBLE, TYPE_DOUBLE, false);
test_non_deterministic_agg_function(ctx, func);
func = get_aggregate_function("any_value", TYPE_VARCHAR, TYPE_VARCHAR, false);
test_non_deterministic_agg_function(ctx, func);
func = get_aggregate_function("any_value", TYPE_DECIMALV2, TYPE_DECIMALV2, false);
test_non_deterministic_agg_function(ctx, func);
func = get_aggregate_function("any_value", TYPE_DATETIME, TYPE_DATETIME, false);
test_non_deterministic_agg_function(ctx, func);
func = get_aggregate_function("any_value", TYPE_DATE, TYPE_DATE, false);
test_non_deterministic_agg_function(ctx, func);
}
完整的改动见 PR:https://github.com/StarRocks/starrocks/pull/2073
本文介绍了 StarRocks 中标量函数和聚合函数的实现原理,并以 sha2 标量函数和 ANY_VALUE 聚合函数为例,说明了如何添加标量函数和新增聚合函数。
标量函数定义在 be/src/exprs/
目录下。若想查看某个函数的实现,可以在函数签名中找到对应的 be function,然后在该目录下使用 grep 进行查找。
此外,StarRocks 还实现了多种聚合函数,具体实现可在 be/src/exprs/agg
目录下查找。
最后,如果你在阅读完本文后对 StarRocks 函数的实现原理以及如何添加新的函数还有很多疑问,欢迎报名参加 4/6(星期四)的
[1]FunctionSet:https://github.com/StarRocks/starrocks/blob/main/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java
[2] ExpressionAnalyzer.Visitor:https://github.com/StarRocks/starrocks/blob/main/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/ExpressionAnalyzer.java#L303 [3]visitFunctionCall :https://github.com/StarRocks/starrocks/blob/main/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/ExpressionAnalyzer.java#L893 [4]BuiltinFunctions::_fn_tables:https://github.com/StarRocks/starrocks/blob/main/be/src/exprs/builtin_functions.h#L75
[5]FunctionDescriptor:https://github.com/StarRocks/starrocks/blob/main/be/src/exprs/builtin_functions.h#L32 [6]sha2 函数:https://docs.starrocks.io/zh-cn/latest/sql-reference/sql-functions/crytographic-functions/sha2#%E5%8A%9F%E8%83%BD [7]EncryptionFunctions:
https://github.com/StarRocks/starrocks/blob/main/be/src/exprs/encryption_functions.h [8]EntryptionFunctionTest:
https://github.com/StarRocks/starrocks/blob/main/be/test/exprs/encryption_functions_test.cpp [9]《StarRocks 聚合算子源码解析》:https://zhuanlan.zhihu.com/p/592058276 [10]ANY_VALUE 功能:https://docs.starrocks.io/zh-cn/latest/sql-reference/sql-functions/aggregate-functions/any_value#%E5%8A%9F%E8%83%BD
[11]AggregateFunction:https://github.com/StarRocks/starrocks/blob/main/fe/fe-core/src/main/java/com/starrocks/catalog/AggregateFunction.java#L61
[12]initAggregateBuiltins:https://github.com/StarRocks/starrocks/blob/main/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java#L742
[13]aggregate_factory.cpp:https://github.com/StarRocks/starrocks/blob/main/be/src/exprs/agg/factory/aggregate_factory.hpp
[14]aggregate_test:https://github.com/StarRocks/starrocks/blob/main/be/test/exprs/agg/aggregate_test.cpp#L1667